diff --git a/mindspore-lite/python/api/__init__.py b/mindspore-lite/python/api/__init__.py
index 28d875529101379aa338e7f9217e186328a14b96..8951b2c0106c4008282c61ea7704b871c2153206 100644
--- a/mindspore-lite/python/api/__init__.py
+++ b/mindspore-lite/python/api/__init__.py
@@ -24,7 +24,15 @@ from importlib.abc import MetaPathFinder
 from mindspore_lite.version import __version__
 from mindspore_lite.context import Context
 from mindspore_lite.converter import FmkType, Converter
-from mindspore_lite.model import ModelType, Model, ModelParallelRunner, ModelGroup, ModelGroupFlag, MultiModelRunner, ModelExecutor
+from mindspore_lite.model import (
+    ModelType,
+    Model,
+    ModelParallelRunner,
+    ModelGroup,
+    ModelGroupFlag,
+    MultiModelRunner,
+    ModelExecutor
+)
 from mindspore_lite.tensor import DataType, Format, Tensor, TensorMeta
 from mindspore_lite.lite_infer import LiteInfer
 from mindspore_lite import lite_infer
diff --git a/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc b/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc
index 8a130bd8440dc171565f44d06f2a4f74f07c2767..cfff4d78f4576d622e8a0b1e7d21a98672e469b4 100644
--- a/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc
+++ b/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc
@@ -52,66 +52,93 @@ void GetSplitNode(const std::shared_ptr<ConverterPara> &param, std::string *spli
   }
 }
 
+size_t FindTopLevelComma(const std::string &block) {
+  int inner_bracket_level = 0;
+  for (size_t j = 1; j < block.length() - 1; j++) {
+    if (block[j] == '[') {
+      inner_bracket_level++;
+    } else if (block[j] == ']') {
+      inner_bracket_level--;
+    } else if (block[j] == ',' && inner_bracket_level == 0) {
+      return j;
+    }
+  }
+  return std::string::npos;
+}
+
+bool ProcessBlock(const std::string &block, std::set<std::string> *op_set,
+                  const std::shared_ptr<ConverterPara> &param) {
+  size_t split_pos = FindTopLevelComma(block);
+  if (split_pos == std::string::npos) {
+    MS_LOG(ERROR) << "Invalid block format: " << block;
+    return false;
+  }
+
+  std::string first_part = block.substr(1, split_pos - 1);
+  std::string second_part = block.substr(split_pos + 1, block.length() - split_pos - 2);
+  std::vector<std::string> first_vector = ParseInnerList(first_part);
+  std::vector<std::string> second_vector = ParseInnerList(second_part);
+  if (second_vector.empty()) {
+    MS_LOG(ERROR) << "Current subgraph output name is empty!";
+    return false;
+  }
+
+  op_set->insert(first_vector.begin(), first_vector.end());
+  op_set->insert(second_vector.begin(), second_vector.end());
+  param->splitGraphCfg.subgraph_input_output.emplace_back(first_vector, second_vector);
+
+  return true;
+}
+
+size_t FindMatchingClosingBracket(const std::string &input, size_t start_pos) {
+  int bracket_level = 0;
+  for (size_t i = start_pos; i < input.length(); i++) {
+    if (input[i] == '[') {
+      bracket_level++;
+    } else if (input[i] == ']') {
+      bracket_level--;
+      if (bracket_level == 0) {
+        return i;
+      }
+    }
+  }
+  return std::string::npos;
+}
+
+size_t FindNextOpeningBracket(const std::string &input, size_t start_pos) {
+  while (start_pos < input.length() && input[start_pos] != '[') {
+    start_pos++;
+  }
+  return start_pos;
+}
+
 STATUS GraphPllitParamParser::ParseGraphSplitCfg(const std::shared_ptr<ConverterPara> &param) {
   MS_CHECK_TRUE_MSG(param != nullptr, RET_ERROR, "param is nullptr!");
-  std::string split_node_str = "";
+  std::string split_node_str;
   GetSplitNode(param, &split_node_str);
-  MS_CHECK_TRUE_RET(!split_node_str.empty(), RET_OK);
-  auto input = split_node_str;
+  if (split_node_str.empty()) {
+    return RET_OK;
+  }
   std::set<std::string> op_set;
   size_t pos = 0;
-  while (pos < input.length()) {
-    while
(pos < input.length() && input[pos] != '[') { - pos++; - } - if (pos >= input.length()) { + const size_t input_length = split_node_str.length(); + while (pos < input_length) { + pos = FindNextOpeningBracket(split_node_str, pos); + if (pos >= input_length) { break; } - int bracket_level = 0; - size_t start_pos = pos; - for (size_t i = start_pos; i < input.length(); i++) { - if (input[i] == '[') { - bracket_level++; - } else if (input[i] == ']') { - bracket_level--; - } - if (bracket_level == 0 && i > start_pos) { - std::string block = input.substr(start_pos, i - start_pos + 1); - int inner_bracket_level = 0; - size_t split_pos = std::string::npos; - for (size_t j = 1; j < block.length() - 1; j++) { - if (block[j] == '[') { - inner_bracket_level++; - } else if (block[j] == ']') { - inner_bracket_level--; - } - if (block[j] == ',' && inner_bracket_level == 0) { - split_pos = j; - break; - } - } - if (split_pos != std::string::npos) { - std::string first_part = block.substr(1, split_pos - 1); - std::string second_part = block.substr(split_pos + 1, block.length() - (split_pos + 1) - 1); - std::vector first_vector = ParseInnerList(first_part); - std::vector second_vector = ParseInnerList(second_part); - for (auto s : first_vector) { - op_set.insert(s); - } - for (auto s : second_vector) { - op_set.insert(s); - } - MS_CHECK_TRUE_MSG(!second_vector.empty(), lite::RET_ERROR, "Current subgraph output name is empty!"); - param->splitGraphCfg.subgraph_input_output.emplace_back(first_vector, second_vector); - } - pos = i + 1; - break; - } + size_t end_pos = FindMatchingClosingBracket(split_node_str, pos); + if (end_pos == std::string::npos) { + MS_LOG(ERROR) << "Unmatched brackets in split config starting at position " << pos; + return lite::RET_ERROR; } + std::string block = split_node_str.substr(pos, end_pos - pos + 1); + if (!ProcessBlock(block, &op_set, param)) { + return lite::RET_ERROR; + } + pos = end_pos + 1; } - for (auto s : op_set) { - 
param->splitGraphCfg.split_node_names.push_back(s); - } + param->splitGraphCfg.split_node_names.assign(op_set.begin(), op_set.end()); return RET_OK; } } // namespace lite diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc index aee53e9550105d957b5468cd3c20500dbe54e555..1124747b92d2dedd34246514b19fe0bc7aff0fb6 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc @@ -130,7 +130,7 @@ int MatMulBaseInt8Coder::MallocQuantParam() { std::vector weight_quant_params = filter_tensor_->quant_params(); int col = filter_tensor_->shape().front(); filter_per_channel_ = (weight_quant_params.size() > 1); - weight_quant_num_ = filter_per_channel_ ? col : 1; + weight_quant_num_ = filter_per_channel_ ? static_cast(col) : 1; quant_.filter_scale_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(float))); MS_CHECK_PTR(quant_.filter_scale_); quant_.filter_zp_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(int32_t))); diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc index 314b50e0f123cfa4f25639bbd47ea1707f244b44..95cbc33328af6b9443ddbbac8c73bd05a84eae60 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -41,34 +41,62 @@ bool CheckDilations(const onnx::AttributeProto &onnx_node_attr) { return true; } +void ParseKernelSize(const std::unique_ptr &prim, const onnx::AttributeProto &onnx_node_attr, + std::vector *kernels, bool *is_3d) { + if (onnx_node_attr.ints_size() == kNumShapeSize2) { + kernels->push_back(onnx_node_attr.ints(0)); + kernels->push_back(onnx_node_attr.ints(kIndex1)); + prim->set_kernel_size(*kernels); 
+  } else if (onnx_node_attr.ints_size() == kNumShapeSize3) {
+    *is_3d = true;
+    kernels->push_back(onnx_node_attr.ints(0));
+    kernels->push_back(onnx_node_attr.ints(kIndex1));
+    kernels->push_back(onnx_node_attr.ints(kIndex2));
+    prim->AddAttr("kernel_size", api::MakeValue<std::vector<int64_t>>(*kernels));
+  }
+}
+
+void ParseStrides(const std::unique_ptr &prim, const onnx::AttributeProto &onnx_node_attr,
+                  std::vector<int64_t> *strides, bool *is_3d) {
+  if (onnx_node_attr.ints_size() == kNumShapeSize2) {
+    strides->push_back(onnx_node_attr.ints(0));
+    strides->push_back(onnx_node_attr.ints(kIndex1));
+  } else if (onnx_node_attr.ints_size() == kNumShapeSize3) {
+    *is_3d = true;
+    strides->push_back(onnx_node_attr.ints(0));
+    strides->push_back(onnx_node_attr.ints(kIndex1));
+    strides->push_back(onnx_node_attr.ints(kIndex2));
+  }
+}
+
+void ParsePads(const std::unique_ptr &prim, const onnx::AttributeProto &onnx_node_attr,
+               std::vector<int64_t> *pads, bool *is_3d) {
+  if (onnx_node_attr.ints_size() == kNumShapeSize4) {
+    pads->push_back(onnx_node_attr.ints(0));
+    pads->push_back(onnx_node_attr.ints(kIndex2));
+    pads->push_back(onnx_node_attr.ints(kIndex1));
+    pads->push_back(onnx_node_attr.ints(kIndex3));
+  } else if (onnx_node_attr.ints_size() == kNumShapeSize6) {
+    *is_3d = true;
+    pads->push_back(onnx_node_attr.ints(0));
+    pads->push_back(onnx_node_attr.ints(kIndex3));
+    pads->push_back(onnx_node_attr.ints(kIndex1));
+    pads->push_back(onnx_node_attr.ints(kIndex4));
+    pads->push_back(onnx_node_attr.ints(kIndex2));
+    pads->push_back(onnx_node_attr.ints(kIndex5));
+  }
+}
+
 bool ParseAttrs(const onnx::NodeProto &onnx_node, const std::unique_ptr &prim, std::vector<int64_t> *kernels,
                 std::vector<int64_t> *strides, std::vector<int64_t> *pads, mindspore::RoundMode *round_mode, bool *is_3d) {
   for (const auto &onnx_node_attr : onnx_node.attribute()) {
     const auto &attribute_name = onnx_node_attr.name();
     if (attribute_name == "kernel_shape") {
-      if (onnx_node_attr.ints_size() == kNumShapeSize2) {
-        kernels->push_back(onnx_node_attr.ints(0));
-        kernels->push_back(onnx_node_attr.ints(kIndex1));
-        prim->set_kernel_size(*kernels);
-      } else if (onnx_node_attr.ints_size() == kNumShapeSize3) {
-        *is_3d = true;
-        kernels->push_back(onnx_node_attr.ints(0));
-        kernels->push_back(onnx_node_attr.ints(kIndex1));
-        kernels->push_back(onnx_node_attr.ints(kIndex2));
-        prim->AddAttr("kernel_size", api::MakeValue<std::vector<int64_t>>(*kernels));
-      }
+      ParseKernelSize(prim, onnx_node_attr, kernels, is_3d);
     }
     if (attribute_name == "strides") {
-      if (onnx_node_attr.ints_size() == kNumShapeSize2) {
-        strides->push_back(onnx_node_attr.ints(0));
-        strides->push_back(onnx_node_attr.ints(kIndex1));
-      } else if (onnx_node_attr.ints_size() == kNumShapeSize3) {
-        *is_3d = true;
-        strides->push_back(onnx_node_attr.ints(0));
-        strides->push_back(onnx_node_attr.ints(kIndex1));
-        strides->push_back(onnx_node_attr.ints(kIndex2));
-      }
+      ParseStrides(prim, onnx_node_attr, strides, is_3d);
     }
     if (attribute_name == "auto_pad") {
       if (onnx_node_attr.s() == "SAME_UPPER") {
@@ -79,20 +107,7 @@ bool ParseAttrs(const onnx::NodeProto &onnx_node, const std::unique_ptr
-      if (onnx_node_attr.ints_size() == kNumShapeSize4) {
-        pads->push_back(onnx_node_attr.ints(0));
-        pads->push_back(onnx_node_attr.ints(kIndex2));
-        pads->push_back(onnx_node_attr.ints(kIndex1));
-        pads->push_back(onnx_node_attr.ints(kIndex3));
-      } else if (onnx_node_attr.ints_size() == kNumShapeSize6) {
-        *is_3d = true;
-        pads->push_back(onnx_node_attr.ints(0));
-        pads->push_back(onnx_node_attr.ints(kIndex3));
-        pads->push_back(onnx_node_attr.ints(kIndex1));
-        pads->push_back(onnx_node_attr.ints(kIndex4));
-        pads->push_back(onnx_node_attr.ints(kIndex2));
-        pads->push_back(onnx_node_attr.ints(kIndex5));
-      }
+      ParsePads(prim, onnx_node_attr, pads, is_3d);
     }
     if (attribute_name == "ceil_mode") {
       *round_mode = (onnx_node_attr.i() == 0) ?
mindspore::RoundMode::FLOOR : mindspore::RoundMode::CEIL; diff --git a/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc b/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc index 6ac2cab14e5539ea8b01b5357e05eaa47b5511bf..2dc31d2f4edba78a2f3d4685db3745ff8ff86809 100644 --- a/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc @@ -168,20 +168,28 @@ bool IsParentNode(const AnfNodePtr &input_node, const AnfNodePtr &output_node) { return false; } -STATUS GetSubgraphStopNodes(const std::vector> &boundaries, const size_t ¤t_index, +bool HasParentInBoundary(const std::vector ¤t_stop_nodes, + const std::vector &output_nodes) { + return std::any_of(output_nodes.begin(), output_nodes.end(), [&](const AnfNodePtr &output_node) { + return std::any_of(current_stop_nodes.begin(), current_stop_nodes.end(), + [&](const AnfNodePtr &stop_node) { return IsParentNode(stop_node, output_node); }); + }); +} + +STATUS GetSubgraphStopNodes(const std::vector> &boundaries, size_t current_index, const std::vector &output_nodes, std::vector *stop_nodes) { - for (int i = current_index - 1; i >= 0; i--) { - auto current_stop_nodes = boundaries[i]; - for (size_t j = 0; j < output_nodes.size(); j++) { - for (size_t k = 0; k < current_stop_nodes.size(); k++) { - if (IsParentNode(current_stop_nodes[k], output_nodes[j])) { - stop_nodes->insert(stop_nodes->end(), current_stop_nodes.begin(), current_stop_nodes.end()); - return lite::RET_OK; - } - } + if (stop_nodes == nullptr) { + MS_LOG(ERROR) << "Stop nodes pointer is null"; + return lite::RET_NULL_PTR; + } + for (int i = static_cast(current_index) - 1; i >= 0; --i) { + const auto ¤t_stop_nodes = boundaries[i]; + if (HasParentInBoundary(current_stop_nodes, output_nodes)) { + stop_nodes->insert(stop_nodes->end(), current_stop_nodes.begin(), current_stop_nodes.end()); + return lite::RET_OK; } } - MS_LOG(ERROR) << "Can not found prenode for current output nodes"; + 
MS_LOG(ERROR) << "Can not find prenode for current output nodes"; return lite::RET_ERROR; } diff --git a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc index ada4fe63f22de09fa8e69a5a156dc8074bed53bf..9df65b1128ddb61278ae87dcb5f62ba46dfeed55 100644 --- a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc @@ -55,6 +55,16 @@ constexpr float kInitOne = 1.0; constexpr size_t kInitBatchSize = 1; constexpr size_t kMaxConfigLen = 1e6; constexpr uint16_t kFloatOne = 15360; +bool IsMatMulNode(const AnfNodePtr &node) { + return mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) || + mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulFusion) || + mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul); +} + +bool IsConv2DNode(const AnfNodePtr &node) { + return mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) || + mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2DFusion); +} } // namespace template @@ -622,93 +632,152 @@ lite::STATUS InsertVariableNodePass::RecordParameterVariableName( lite::STATUS InsertVariableNodePass::BuildVariableNode(const std::shared_ptr ¶m, FuncGraphPtr func_graph, std::vector *const_names) { - MS_CHECK_TRUE_RET(func_graph != nullptr, RET_ERROR); - std::string variable_weights_file = ""; + if (func_graph == nullptr) { + return RET_ERROR; + } + std::string variable_weights_file; int32_t max_weight_batch = 1; InitWeightParam(param, &variable_weights_file, &max_weight_batch); - MS_CHECK_TRUE_RET(variable_weights_file != "", RET_OK); + if (variable_weights_file.empty()) { + return RET_OK; + } bool has_alpha = false; std::map> variable_nodes; std::unordered_map node_name_map; std::vector node_name_list; auto ret = ParseInsertNode(variable_weights_file, &variable_nodes, &node_name_map, 
&node_name_list, &has_alpha); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "ParseInsertNode failed!"); - uint32_t matched_num = 0; + if (ret != RET_OK) { + MS_LOG(ERROR) << "ParseInsertNode failed!"; + return ret; + } auto node_list = TopoSort(func_graph->get_return()); + uint32_t matched_num = 0; for (auto &node : node_list) { - MS_CHECK_TRUE_RET(node != nullptr, false); - auto node_name = node->fullname_with_scope(); - size_t last_slash_pos = node_name.find_last_of('/'); - std::string search_key = ""; + if (node == nullptr) { + return RET_ERROR; + } if (utils::isa(node)) { - search_key = node_name; - if (variable_nodes.find(search_key) == variable_nodes.end()) { - continue; - } - auto parameter = node->cast(); - if (parameter == nullptr || !parameter->has_default()) { - continue; - } - ret = RecordParameterVariableName(func_graph, parameter, search_key, false, &node_name_map); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record parameter variable name failed!"); - } else if (utils::isa(node)) { - if (last_slash_pos != std::string::npos) { - search_key = node_name.substr(0, last_slash_pos); - } else { - MS_LOG(INFO) << "Not found last slash, Cnode name:" << node->fullname_with_scope() << "!"; - continue; - } - if (variable_nodes.find(search_key) == variable_nodes.end()) { - continue; - } - auto cnode = utils::cast(node); - MS_CHECK_TRUE_RET(cnode != nullptr, false); - if (mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulFusion) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul)) { - bool replace_origin = false; - ret = CheckOnlyReplace(cnode, variable_nodes.at(search_key), true, &replace_origin); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "CheckOnlyReplace failed!"); - if (replace_origin) { - ret = RecordVariableName(func_graph, cnode, search_key, true, &node_name_map); - } else { - ret = InsertVariableNodeForMatmul(node, cnode, func_graph, 
variable_nodes.at(search_key), &node_name_map, - has_alpha, max_weight_batch); - } - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record variable name failed!"); - } else if (mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2DFusion)) { - bool replace_origin = false; - ret = CheckOnlyReplace(cnode, variable_nodes.at(search_key), false, &replace_origin); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "CheckOnlyReplace failed!"); - if (replace_origin) { - ret = RecordVariableName(func_graph, cnode, search_key, false, &node_name_map); - } else { - ret = InsertVariableNodeForConv(node, cnode, func_graph, variable_nodes.at(search_key), &node_name_map, - has_alpha, max_weight_batch); - } - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record variable name failed!"); - } else { - continue; - } - } else { + matched_num += ProcessParameterNode(func_graph, node, variable_nodes, &node_name_map) ? 1 : 0; + continue; + } + if (utils::isa(node)) { + matched_num += + ProcessCNode(func_graph, node, variable_nodes, &node_name_map, has_alpha, max_weight_batch) ? 
1 : 0; continue; } - matched_num++; } if (matched_num != variable_nodes.size()) { - MS_LOG(ERROR) << "matched num:" << matched_num << " != all node num:" << variable_nodes.size() << "!"; + MS_LOG(ERROR) << "Matched num: " << matched_num << " != all node num: " << variable_nodes.size() << "!"; return RET_ERROR; } - for (auto s : node_name_list) { - if (node_name_map.find(s) == node_name_map.end()) { - continue; + + for (const auto &name : node_name_list) { + auto it = node_name_map.find(name); + if (it != node_name_map.end()) { + const_names->push_back(it->second); } - const_names->push_back(node_name_map[s]); } return RET_OK; } +bool InsertVariableNodePass::ProcessParameterNode(FuncGraphPtr func_graph, const AnfNodePtr &node, + const std::map> &variable_nodes, + std::unordered_map *node_name_map) { + auto node_name = node->fullname_with_scope(); + if (variable_nodes.find(node_name) == variable_nodes.end()) { + return false; + } + auto parameter = node->cast(); + if (parameter == nullptr || !parameter->has_default()) { + return false; + } + auto ret = RecordParameterVariableName(func_graph, parameter, node_name, false, node_name_map); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Record parameter variable name failed!"; + return false; + } + return true; +} + +bool InsertVariableNodePass::ProcessCNode(FuncGraphPtr func_graph, const AnfNodePtr &node, + const std::map> &variable_nodes, + std::unordered_map *node_name_map, bool has_alpha, + int32_t max_weight_batch) { + auto node_name = node->fullname_with_scope(); + size_t last_slash_pos = node_name.find_last_of('/'); + std::string search_key; + if (last_slash_pos == std::string::npos) { + MS_LOG(INFO) << "Not found last slash, Cnode name:" << node->fullname_with_scope() << "!"; + return false; + } + search_key = node_name.substr(0, last_slash_pos); + if (variable_nodes.find(search_key) == variable_nodes.end()) { + return false; + } + auto cnode = utils::cast(node); + if (cnode == nullptr) { + return false; + } + if 
(IsMatMulNode(node)) { + return ProcessMatMulNode(func_graph, node, cnode, search_key, variable_nodes.at(search_key), node_name_map, + has_alpha, max_weight_batch); + } + if (IsConv2DNode(node)) { + return ProcessConv2DNode(func_graph, node, cnode, search_key, variable_nodes.at(search_key), node_name_map, + has_alpha, max_weight_batch); + } + return false; +} + +bool InsertVariableNodePass::ProcessMatMulNode(FuncGraphPtr func_graph, const AnfNodePtr &node, const CNodePtr &cnode, + const std::string &search_key, const std::vector &node_info, + std::unordered_map *node_name_map, + bool has_alpha, int32_t max_weight_batch) { + bool replace_origin = false; + auto ret = CheckOnlyReplace(cnode, node_info, true, &replace_origin); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CheckOnlyReplace failed!"; + return false; + } + + if (replace_origin) { + ret = RecordVariableName(func_graph, cnode, search_key, true, node_name_map); + } else { + ret = InsertVariableNodeForMatmul(node, cnode, func_graph, node_info, node_name_map, has_alpha, max_weight_batch); + } + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Record variable name failed!"; + return false; + } + return true; +} + +bool InsertVariableNodePass::ProcessConv2DNode(FuncGraphPtr func_graph, const AnfNodePtr &node, const CNodePtr &cnode, + const std::string &search_key, const std::vector &node_info, + std::unordered_map *node_name_map, + bool has_alpha, int32_t max_weight_batch) { + bool replace_origin = false; + auto ret = CheckOnlyReplace(cnode, node_info, false, &replace_origin); + if (ret != RET_OK) { + MS_LOG(ERROR) << "CheckOnlyReplace failed!"; + return false; + } + + if (replace_origin) { + ret = RecordVariableName(func_graph, cnode, search_key, false, node_name_map); + } else { + ret = InsertVariableNodeForConv(node, cnode, func_graph, node_info, node_name_map, has_alpha, max_weight_batch); + } + + if (ret != RET_OK) { + MS_LOG(ERROR) << "Record variable name failed!"; + return false; + } + + return true; +} + bool 
InsertVariableNodePass::Run(const FuncGraphPtr &graph) { if (BuildVariableNode(param_, graph, &(param_->const_names)) != RET_OK) { return false; diff --git a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h index ba623d6771f734f60bdde58336371fe7d172cd78..6a5c2ea06f1b73ef7e19375fcd260e1622e0b53e 100644 --- a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h +++ b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h @@ -38,6 +38,21 @@ class InsertVariableNodePass : public Pass { private: lite::STATUS BuildVariableNode(const std::shared_ptr ¶m, FuncGraphPtr func_graph, std::vector *const_names); + bool ProcessParameterNode(FuncGraphPtr func_graph, const AnfNodePtr &node, + const std::map> &variable_nodes, + std::unordered_map *node_name_map); + bool ProcessCNode(FuncGraphPtr func_graph, const AnfNodePtr &node, + const std::map> &variable_nodes, + std::unordered_map *node_name_map, bool has_alpha, + int32_t max_weight_batch); + bool ProcessMatMulNode(FuncGraphPtr func_graph, const AnfNodePtr &node, const CNodePtr &cnode, + const std::string &search_key, const std::vector &node_info, + std::unordered_map *node_name_map, bool has_alpha, + int32_t max_weight_batch); + bool ProcessConv2DNode(FuncGraphPtr func_graph, const AnfNodePtr &node, const CNodePtr &cnode, + const std::string &search_key, const std::vector &node_info, + std::unordered_map *node_name_map, bool has_alpha, + int32_t max_weight_batch); lite::STATUS InsertVariableNodeForMatmul(const AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph, const std::vector &up_shape, std::unordered_map *node_name_map, bool has_alpha,