From cf1b7ee4fceeb33de999778adf7f4fff1737f134 Mon Sep 17 00:00:00 2001 From: jianghui58 Date: Tue, 16 Sep 2025 16:34:07 +0800 Subject: [PATCH] code check fix --- mindspore-lite/python/api/__init__.py | 10 +- .../config_parser/graph_split_param_parser.cc | 139 +++++++----- .../nnacl/int8/conv2d_3x3_int8_coder.cc | 9 +- .../nnacl/int8/matmul_base_int8_coder.cc | 2 +- .../converter/parser/onnx/onnx_pool_parser.cc | 83 ++++--- .../optimizer/fusion/graph_split_pass.cc | 34 ++- .../optimizer/graph/add_variable_node_pass.cc | 211 ++++++++++++------ .../optimizer/graph/add_variable_node_pass.h | 19 ++ 8 files changed, 342 insertions(+), 165 deletions(-) diff --git a/mindspore-lite/python/api/__init__.py b/mindspore-lite/python/api/__init__.py index 28d87552..8951b2c0 100644 --- a/mindspore-lite/python/api/__init__.py +++ b/mindspore-lite/python/api/__init__.py @@ -24,7 +24,15 @@ from importlib.abc import MetaPathFinder from mindspore_lite.version import __version__ from mindspore_lite.context import Context from mindspore_lite.converter import FmkType, Converter -from mindspore_lite.model import ModelType, Model, ModelParallelRunner, ModelGroup, ModelGroupFlag, MultiModelRunner, ModelExecutor +from mindspore_lite.model import ( + ModelType, + Model, + ModelParallelRunner, + ModelGroup, + ModelGroupFlag, + MultiModelRunner, + ModelExecutor +) from mindspore_lite.tensor import DataType, Format, Tensor, TensorMeta from mindspore_lite.lite_infer import LiteInfer from mindspore_lite import lite_infer diff --git a/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc b/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc index 8a130bd8..f662ec22 100644 --- a/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc +++ b/mindspore-lite/tools/converter/config_parser/graph_split_param_parser.cc @@ -52,66 +52,103 @@ void GetSplitNode(const std::shared_ptr ¶m, std::string *spli } } +size_t FindTopLevelComma(const std::string &block) { + int inner_bracket_level = 0; + for (size_t j = 1; j < block.length() - 1; j++) { + if (block[j] == '[') { + inner_bracket_level++; + } else if (block[j] == ']') { + inner_bracket_level--; + } else if (block[j] == ',' && inner_bracket_level == 0) { + return j; + } + } + return std::string::npos; +} + +size_t FindTopLevelComma(const std::string &block) { + int inner_bracket_level = 0; + for (size_t j = 1; j < block.length() - 1; j++) { + if (block[j] == '[') { + inner_bracket_level++; + } else if (block[j] == ']') { + inner_bracket_level--; + } else if (block[j] == ',' && inner_bracket_level == 0) { + return j; + } + } + return std::string::npos; +} +bool ProcessBlock(const std::string &block, std::set &op_set, + const std::shared_ptr ¶m) { + size_t split_pos = FindTopLevelComma(block); + if (split_pos == std::string::npos) { + MS_LOG(ERROR) << "Invalid block format: " << block; + return false; + } + + std::string first_part = block.substr(1, split_pos - 1); + std::string second_part = block.substr(split_pos + 1, block.length() - split_pos - 2); + std::vector first_vector = ParseInnerList(first_part); + std::vector second_vector = ParseInnerList(second_part); + if (second_vector.empty()) { + MS_LOG(ERROR) << "Current subgraph output name is empty!"; + return false; + } + + op_set.insert(first_vector.begin(), first_vector.end()); + op_set.insert(second_vector.begin(), second_vector.end()); + param->splitGraphCfg.subgraph_input_output.emplace_back(first_vector, second_vector); + + return true; +} +size_t FindMatchingClosingBracket(const std::string &input, size_t start_pos) { + int bracket_level = 0; + for (size_t i = start_pos; i < input.length(); i++) { + if (input[i] == '[') { + bracket_level++; + } else if (input[i] == ']') { + bracket_level--; + if (bracket_level == 0) { + return i; + } + } + } + return std::string::npos; // 未找到匹配的] +} +size_t FindNextOpeningBracket(const std::string &input, size_t start_pos) { + while (start_pos < input.length() && input[start_pos] != '[') { + start_pos++; + } + return start_pos; +} + STATUS GraphPllitParamParser::ParseGraphSplitCfg(const std::shared_ptr ¶m) { MS_CHECK_TRUE_MSG(param != nullptr, RET_ERROR, "param is nullptr!"); - std::string split_node_str = ""; + std::string split_node_str; GetSplitNode(param, &split_node_str); - MS_CHECK_TRUE_RET(!split_node_str.empty(), RET_OK); - auto input = split_node_str; + if (split_node_str.empty()) { + return RET_OK; + } std::set op_set; size_t pos = 0; - while (pos < input.length()) { - while (pos < input.length() && input[pos] != '[') { - pos++; - } - if (pos >= input.length()) { + const size_t input_length = split_node_str.length(); + while (pos < input_length) { + pos = FindNextOpeningBracket(split_node_str, pos); + if (pos >= input_length) { break; } - int bracket_level = 0; - size_t start_pos = pos; - for (size_t i = start_pos; i < input.length(); i++) { - if (input[i] == '[') { - bracket_level++; - } else if (input[i] == ']') { - bracket_level--; - } - if (bracket_level == 0 && i > start_pos) { - std::string block = input.substr(start_pos, i - start_pos + 1); - int inner_bracket_level = 0; - size_t split_pos = std::string::npos; - for (size_t j = 1; j < block.length() - 1; j++) { - if (block[j] == '[') { - inner_bracket_level++; - } else if (block[j] == ']') { - inner_bracket_level--; - } - if (block[j] == ',' && inner_bracket_level == 0) { - split_pos = j; - break; - } - } - if (split_pos != std::string::npos) { - std::string first_part = block.substr(1, split_pos - 1); - std::string second_part = block.substr(split_pos + 1, block.length() - (split_pos + 1) - 1); - std::vector first_vector = ParseInnerList(first_part); - std::vector second_vector = ParseInnerList(second_part); - for (auto s : first_vector) { - op_set.insert(s); - } - for (auto s : second_vector) { - op_set.insert(s); - } - MS_CHECK_TRUE_MSG(!second_vector.empty(), lite::RET_ERROR, "Current subgraph output name is empty!"); - param->splitGraphCfg.subgraph_input_output.emplace_back(first_vector, second_vector); - } - pos = i + 1; - break; - } + if (end_pos == std::string::npos) { + MS_LOG(ERROR) << "Unmatched brackets in split config: " << split_node_str.substr(pos); + return lite::RET_ERROR; } + std::string block = split_node_str.substr(pos, end_pos - pos + 1); + if (!ProcessBlock(block, op_set, param)) { + return lite::RET_ERROR; + } + pos = end_pos + 1; } - for (auto s : op_set) { - param->splitGraphCfg.split_node_names.push_back(s); - } + param->splitGraphCfg.split_node_names.assign(op_set.begin(), op_set.end()); return RET_OK; } } // namespace lite diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc index 8e8b594d..d2687521 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/conv2d_3x3_int8_coder.cc @@ -110,7 +110,14 @@ int Conv2D3x3Int8Coder::InitTmpBuffer(CoderContext *const context) { } void Conv2D3x3Int8Coder::ConfigInputOutput() { output_tensor_->set_format(mindspore::NHWC); } - +bool Conv2DConditionalJudgement(const int input_channel, const int output_channel, const int input_h, + const int input_w) { + static const std::set> forbidden_combinations = { + {kNumber1, kNumber16, kNumber25, kNumber24}, + {kNumber16, kNumber32, kNumber12, kNumber12}, + {kNumber32, kNumber64, kNumber6, kNumber6}}; + return !forbidden_combinations.count(std::make_tuple(input_channel, output_channel, input_h, input_w)); +} int Conv2D3x3Int8Coder::Prepare(CoderContext *const context) { MS_CHECK_RET_CODE(Conv2DBaseCoder::Init(), "ConvolutionBase init failed."); conv_param_->thread_num_ = thread_num_; diff --git a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc index aee53e95..1124747b 100644 --- a/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc +++ b/mindspore-lite/tools/converter/micro/coder/opcoders/nnacl/int8/matmul_base_int8_coder.cc @@ -130,7 +130,7 @@ int MatMulBaseInt8Coder::MallocQuantParam() { std::vector weight_quant_params = filter_tensor_->quant_params(); int col = filter_tensor_->shape().front(); filter_per_channel_ = (weight_quant_params.size() > 1); - weight_quant_num_ = filter_per_channel_ ? col : 1; + weight_quant_num_ = filter_per_channel_ ? static_cast(col) : 1; quant_.filter_scale_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(float))); MS_CHECK_PTR(quant_.filter_scale_); quant_.filter_zp_ = reinterpret_cast(malloc(weight_quant_num_ * sizeof(int32_t))); diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc index 314b50e0..b7385447 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_pool_parser.cc @@ -41,34 +41,62 @@ bool CheckDilations(const onnx::AttributeProto &onnx_node_attr) { return true; } +void ParseKernelSize(const std::unique_ptr &prim, const onnx::AttributeProto &onnx_node_attr, + td::vector *kernels, bool *is_3d) { + if (onnx_node_attr.ints_size() == kNumShapeSize2) { + kernels->push_back(onnx_node_attr.ints(0)); + kernels->push_back(onnx_node_attr.ints(kIndex1)); + prim->set_kernel_size(*kernels); + } else if (onnx_node_attr.ints_size() == kNumShapeSize3) { + *is_3d = true; + kernels->push_back(onnx_node_attr.ints(0)); + kernels->push_back(onnx_node_attr.ints(kIndex1)); + kernels->push_back(onnx_node_attr.ints(kIndex2)); + prim->AddAttr("kernel_size", api::MakeValue>(*kernels)); + } +} + +void ParseStrides(const std::unique_ptr &prim, const onnx::AttributeProto &onnx_node_attr, + td::vector *strides, bool *is_3d) { + if (onnx_node_attr.ints_size() == kNumShapeSize2) { + strides->push_back(onnx_node_attr.ints(0)); + strides->push_back(onnx_node_attr.ints(kIndex1)); + } else if (onnx_node_attr.ints_size() == kNumShapeSize3) { + *is_3d = true; + strides->push_back(onnx_node_attr.ints(0)); + strides->push_back(onnx_node_attr.ints(kIndex1)); + strides->push_back(onnx_node_attr.ints(kIndex2)); + } +} + +void ParsePads(const std::unique_ptr &prim, const onnx::AttributeProto &onnx_node_attr, + td::vector *pads, bool *is_3d) { + if (onnx_node_attr.ints_size() == kNumShapeSize4) { + pads->push_back(onnx_node_attr.ints(0)); + pads->push_back(onnx_node_attr.ints(kIndex2)); + pads->push_back(onnx_node_attr.ints(kIndex1)); + pads->push_back(onnx_node_attr.ints(kIndex3)); + } else if (onnx_node_attr.ints_size() == kNumShapeSize6) { + *is_3d = true; + pads->push_back(onnx_node_attr.ints(0)); + pads->push_back(onnx_node_attr.ints(kIndex3)); + pads->push_back(onnx_node_attr.ints(kIndex1)); + pads->push_back(onnx_node_attr.ints(kIndex4)); + pads->push_back(onnx_node_attr.ints(kIndex2)); + pads->push_back(onnx_node_attr.ints(kIndex5)); + } +} + bool ParseAttrs(const onnx::NodeProto &onnx_node, const std::unique_ptr &prim, std::vector *kernels, std::vector *strides, std::vector *pads, mindspore::RoundMode *round_mode, bool *is_3d) { for (const auto &onnx_node_attr : onnx_node.attribute()) { const auto &attribute_name = onnx_node_attr.name(); if (attribute_name == "kernel_shape") { - if (onnx_node_attr.ints_size() == kNumShapeSize2) { - kernels->push_back(onnx_node_attr.ints(0)); - kernels->push_back(onnx_node_attr.ints(kIndex1)); - prim->set_kernel_size(*kernels); - } else if (onnx_node_attr.ints_size() == kNumShapeSize3) { - *is_3d = true; - kernels->push_back(onnx_node_attr.ints(0)); - kernels->push_back(onnx_node_attr.ints(kIndex1)); - kernels->push_back(onnx_node_attr.ints(kIndex2)); - prim->AddAttr("kernel_size", api::MakeValue>(*kernels)); - } + ParseKernelSize(prim, onnx_node_attr, kernels, is_3d); } if (attribute_name == "strides") { - if (onnx_node_attr.ints_size() == kNumShapeSize2) { - strides->push_back(onnx_node_attr.ints(0)); - strides->push_back(onnx_node_attr.ints(kIndex1)); - } else if (onnx_node_attr.ints_size() == kNumShapeSize3) { - *is_3d = true; - strides->push_back(onnx_node_attr.ints(0)); - strides->push_back(onnx_node_attr.ints(kIndex1)); - strides->push_back(onnx_node_attr.ints(kIndex2)); - } + ParseStrides(prim, onnx_node_attr, strides, is_3d); } if (attribute_name == "auto_pad") { if (onnx_node_attr.s() == "SAME_UPPER") { @@ -79,20 +107,7 @@ bool ParseAttrs(const onnx::NodeProto &onnx_node, const std::unique_ptrpush_back(onnx_node_attr.ints(0)); - pads->push_back(onnx_node_attr.ints(kIndex2)); - pads->push_back(onnx_node_attr.ints(kIndex1)); - pads->push_back(onnx_node_attr.ints(kIndex3)); - } else if (onnx_node_attr.ints_size() == kNumShapeSize6) { - *is_3d = true; - pads->push_back(onnx_node_attr.ints(0)); - pads->push_back(onnx_node_attr.ints(kIndex3)); - pads->push_back(onnx_node_attr.ints(kIndex1)); - pads->push_back(onnx_node_attr.ints(kIndex4)); - pads->push_back(onnx_node_attr.ints(kIndex2)); - pads->push_back(onnx_node_attr.ints(kIndex5)); - } + ParsePads(prim, onnx_node_attr, pads, is_3d) } if (attribute_name == "ceil_mode") { *round_mode = (onnx_node_attr.i() == 0) ? mindspore::RoundMode::FLOOR : mindspore::RoundMode::CEIL; diff --git a/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc b/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc index 6ac2cab1..25f9a927 100644 --- a/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc @@ -168,20 +168,32 @@ bool IsParentNode(const AnfNodePtr &input_node, const AnfNodePtr &output_node) { return false; } -STATUS GetSubgraphStopNodes(const std::vector> &boundaries, const size_t ¤t_index, - const std::vector &output_nodes, std::vector *stop_nodes) { - for (int i = current_index - 1; i >= 0; i--) { - auto current_stop_nodes = boundaries[i]; - for (size_t j = 0; j < output_nodes.size(); j++) { - for (size_t k = 0; k < current_stop_nodes.size(); k++) { - if (IsParentNode(current_stop_nodes[k], output_nodes[j])) { - stop_nodes->insert(stop_nodes->end(), current_stop_nodes.begin(), current_stop_nodes.end()); - return lite::RET_OK; - } +bool HasParentInBoundary(const std::vector ¤t_stop_nodes, + const std::vector &output_nodes) { + for (const auto &output_node : output_nodes) { + for (const auto &stop_node : current_stop_nodes) { + if (IsParentNode(stop_node, output_node)) { + return true; } } } - MS_LOG(ERROR) << "Can not found prenode for current output nodes"; + return false; +} + +STATUS GetSubgraphStopNodes(const std::vector> &boundaries, size_t current_index, + const std::vector &output_nodes, std::vector *stop_nodes) { + if (stop_nodes == nullptr) { + MS_LOG(ERROR) << "Stop nodes pointer is null"; + return lite::RET_NULL_PTR; + } + for (int i = static_cast(current_index) - 1; i >= 0; --i) { + const auto ¤t_stop_nodes = boundaries[i]; + if (HasParentInBoundary(current_stop_nodes, output_nodes)) { + stop_nodes->insert(stop_nodes->end(), current_stop_nodes.begin(), current_stop_nodes.end()); + return lite::RET_OK; + } + } + MS_LOG(ERROR) << "Can not find prenode for current output nodes"; return lite::RET_ERROR; } diff --git a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc index ada4fe63..4ae298ad 100644 --- a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc @@ -55,6 +55,17 @@ constexpr float kInitOne = 1.0; constexpr size_t kInitBatchSize = 1; constexpr size_t kMaxConfigLen = 1e6; constexpr uint16_t kFloatOne = 15360; + +bool IsMatMulNode(const AnfNodePtr &node) { + return mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) || + mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulFusion) || + mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul); +} + +bool IsConv2DNode(const AnfNodePtr &node) { + return mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) || + mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2DFusion); +} } // namespace template @@ -633,80 +644,148 @@ lite::STATUS InsertVariableNodePass::BuildVariableNode(const std::shared_ptr node_name_list; auto ret = ParseInsertNode(variable_weights_file, &variable_nodes, &node_name_map, &node_name_list, &has_alpha); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "ParseInsertNode failed!"); + uint32_t matched_num = ProcessGraphNodes(func_graph, variable_nodes, node_name_map, has_alpha, max_weight_batch); + if (matched_num != total_num) { + MS_LOG(ERROR) << "matched num:" << matched_num << " != all node num:" << total_num << "!"; + return RET_ERROR; + } + for (const auto &name : node_name_list) { + auto it = node_name_map.find(name); + if (it != node_name_map.end()) { + const_names->push_back(it->second); + } + } + return RET_OK; +} + +bool InsertVariableNodePass::ValidateInputParameters(const FuncGraphPtr &func_graph) { + if (func_graph == nullptr) { + return false; + } + return true; +} + +uint32_t InsertVariableNodePass::ProcessGraphNodes(const FuncGraphPtr &func_graph, + const std::map> &variable_nodes, + std::unordered_map &node_name_map, + bool has_alpha, int32_t max_weight_batch) { uint32_t matched_num = 0; auto node_list = TopoSort(func_graph->get_return()); for (auto &node : node_list) { - MS_CHECK_TRUE_RET(node != nullptr, false); - auto node_name = node->fullname_with_scope(); - size_t last_slash_pos = node_name.find_last_of('/'); - std::string search_key = ""; - if (utils::isa(node)) { - search_key = node_name; - if (variable_nodes.find(search_key) == variable_nodes.end()) { - continue; - } - auto parameter = node->cast(); - if (parameter == nullptr || !parameter->has_default()) { - continue; - } - ret = RecordParameterVariableName(func_graph, parameter, search_key, false, &node_name_map); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record parameter variable name failed!"); - } else if (utils::isa(node)) { - if (last_slash_pos != std::string::npos) { - search_key = node_name.substr(0, last_slash_pos); - } else { - MS_LOG(INFO) << "Not found last slash, Cnode name:" << node->fullname_with_scope() << "!"; - continue; - } - if (variable_nodes.find(search_key) == variable_nodes.end()) { - continue; - } - auto cnode = utils::cast(node); - MS_CHECK_TRUE_RET(cnode != nullptr, false); - if (mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulFusion) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul)) { - bool replace_origin = false; - ret = CheckOnlyReplace(cnode, variable_nodes.at(search_key), true, &replace_origin); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "CheckOnlyReplace failed!"); - if (replace_origin) { - ret = RecordVariableName(func_graph, cnode, search_key, true, &node_name_map); - } else { - ret = InsertVariableNodeForMatmul(node, cnode, func_graph, variable_nodes.at(search_key), &node_name_map, - has_alpha, max_weight_batch); - } - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record variable name failed!"); - } else if (mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2DFusion)) { - bool replace_origin = false; - ret = CheckOnlyReplace(cnode, variable_nodes.at(search_key), false, &replace_origin); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "CheckOnlyReplace failed!"); - if (replace_origin) { - ret = RecordVariableName(func_graph, cnode, search_key, false, &node_name_map); - } else { - ret = InsertVariableNodeForConv(node, cnode, func_graph, variable_nodes.at(search_key), &node_name_map, - has_alpha, max_weight_batch); - } - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record variable name failed!"); - } else { - continue; - } - } else { + if (node == nullptr) { continue; } - matched_num++; - } - if (matched_num != variable_nodes.size()) { - MS_LOG(ERROR) << "matched num:" << matched_num << " != all node num:" << variable_nodes.size() << "!"; - return RET_ERROR; - } - for (auto s : node_name_list) { - if (node_name_map.find(s) == node_name_map.end()) { + if (ProcessParameterNode(node, variable_nodes, func_graph, node_name_map)) { + matched_num++; + continue; + } + if (ProcessCNode(node, variable_nodes, func_graph, node_name_map, has_alpha, max_weight_batch)) { + matched_num++; continue; } - const_names->push_back(node_name_map[s]); } - return RET_OK; + return matched_num; +} + +bool InsertVariableNodePass::ProcessParameterNode(const AnfNodePtr &node, + const std::map> &variable_nodes, + const FuncGraphPtr &func_graph, + std::unordered_map &node_name_map) { + if (!utils::isa(node)) { + return false; + } + + auto node_name = node->fullname_with_scope(); + if (variable_nodes.find(node_name) == variable_nodes.end()) { + return false; + } + + auto parameter = node->cast(); + if (parameter == nullptr || !parameter->has_default()) { + return false; + } + + auto ret = RecordParameterVariableName(func_graph, parameter, node_name, false, &node_name_map); + MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record parameter variable name failed!"); + return true; +} + +// 处理CNode节点 +bool InsertVariableNodePass::ProcessCNode(const AnfNodePtr &node, + const std::map> &variable_nodes, + const FuncGraphPtr &func_graph, + std::unordered_map &node_name_map, bool has_alpha, + int32_t max_weight_batch) { + if (!utils::isa(node)) { + return false; + } + + auto node_name = node->fullname_with_scope(); + size_t last_slash_pos = node_name.find_last_of('/'); + if (last_slash_pos == std::string::npos) { + MS_LOG(INFO) << "Not found last slash, Cnode name:" << node->fullname_with_scope() << "!"; + return false; + } + + std::string search_key = node_name.substr(0, last_slash_pos); + if (variable_nodes.find(search_key) == variable_nodes.end()) { + return false; + } + + auto cnode = utils::cast(node); + if (cnode == nullptr) { + return false; + } + + if (IsMatMulNode(node)) { + return ProcessMatMulNode(node, cnode, func_graph, variable_nodes.at(search_key), node_name_map, has_alpha, + max_weight_batch); + } + + if (IsConv2DNode(node)) { + return ProcessConv2DNode(node, cnode, func_graph, variable_nodes.at(search_key), node_name_map, has_alpha, + max_weight_batch); + } + + return false; +} + +bool InsertVariableNodePass::ProcessMatMulNode(const AnfNodePtr &node, const CNodePtr &cnode, + const FuncGraphPtr &func_graph, const std::vector &variable_node, + std::unordered_map &node_name_map, + bool has_alpha, int32_t max_weight_batch) { + bool replace_origin = false; + auto ret = CheckOnlyReplace(cnode, variable_node, true, &replace_origin); + MS_CHECK_TRUE_MSG(ret == RET_OK, false, "CheckOnlyReplace failed!"); + + if (replace_origin) { + ret = RecordVariableName(func_graph, cnode, search_key, true, &node_name_map); + } else { + ret = + InsertVariableNodeForMatmul(node, cnode, func_graph, variable_node, &node_name_map, has_alpha, max_weight_batch); + } + + MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record variable name failed!"); + return true; +} + +bool InsertVariableNodePass::ProcessConv2DNode(const AnfNodePtr &node, const CNodePtr &cnode, + const FuncGraphPtr &func_graph, const std::vector &variable_node, + std::unordered_map &node_name_map, + bool has_alpha, int32_t max_weight_batch) { + bool replace_origin = false; + auto ret = CheckOnlyReplace(cnode, variable_node, false, &replace_origin); + MS_CHECK_TRUE_MSG(ret == RET_OK, false, "CheckOnlyReplace failed!"); + + if (replace_origin) { + ret = RecordVariableName(func_graph, cnode, search_key, false, &node_name_map); + } else { + ret = + InsertVariableNodeForConv(node, cnode, func_graph, variable_node, &node_name_map, has_alpha, max_weight_batch); + } + + MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record variable name failed!"); + return true; } bool InsertVariableNodePass::Run(const FuncGraphPtr &graph) { diff --git a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h index ba623d67..6c3a666e 100644 --- a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h +++ b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h @@ -38,6 +38,25 @@ class InsertVariableNodePass : public Pass { private: lite::STATUS BuildVariableNode(const std::shared_ptr ¶m, FuncGraphPtr func_graph, std::vector *const_names); + uint32_t ProcessGraphNodes(const FuncGraphPtr &func_graph, + const std::map> &variable_nodes, + std::unordered_map &node_name_map, bool has_alpha, + int32_t max_weight_batch); + bool ProcessParameterNode(const AnfNodePtr &node, const std::map> &variable_nodes, + const FuncGraphPtr &func_graph, + std::unordered_map &node_name_map); + + bool ProcessCNode(const AnfNodePtr &node, const std::map> &variable_nodes, + const FuncGraphPtr &func_graph, std::unordered_map &node_name_map, + bool has_alpha, int32_t max_weight_batch); + bool ProcessMatMulNode(const AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph, + const std::vector &variable_node, + std::unordered_map &node_name_map, bool has_alpha, + int32_t max_weight_batch); + bool ProcessConv2DNode(const AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph, + const std::vector &variable_node, + std::unordered_map &node_name_map, bool has_alpha, + int32_t max_weight_batch); lite::STATUS InsertVariableNodeForMatmul(const AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph, const std::vector &up_shape, std::unordered_map *node_name_map, bool has_alpha, -- Gitee