From 82081095368a14d44b915460c15015380de37876 Mon Sep 17 00:00:00 2001 From: z00621985 Date: Wed, 16 Jul 2025 16:49:43 +0800 Subject: [PATCH 1/3] control flow --- .../delegate/ascend_ge/ge_graph_executor.cc | 2 +- mindspore-lite/tools/common/node_util.cc | 11 + mindspore-lite/tools/common/node_util.h | 2 + .../cxx_api/model/acl/model_converter.cc | 2 +- .../adapter/acl/src/acl_pass_impl.cc | 7 + .../tools/converter/anf_transform.cc | 7 +- .../parser/onnx/onnx_model_parser.cc | 311 ++++++++++-------- .../converter/parser/onnx/onnx_model_parser.h | 19 +- .../fusion/adjust_controlflow_pass.cc | 72 ++-- .../optimizer/graph/cond_tensor_to_scalar.cc | 100 ++++++ .../optimizer/graph/cond_tensor_to_scalar.h | 53 +++ .../optimizer/graph/control_flow_pass.cc | 10 +- .../graph/decrease_transpose_algo.cc | 4 +- .../optimizer/graph/if_to_partial_pass.cc | 223 +++++++++++++ .../optimizer/graph/if_to_partial_pass.h | 69 ++++ .../tools/optimizer/graph/infershape_pass.cc | 3 +- 16 files changed, 732 insertions(+), 163 deletions(-) create mode 100644 mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.cc create mode 100644 mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.h create mode 100644 mindspore-lite/tools/optimizer/graph/if_to_partial_pass.cc create mode 100644 mindspore-lite/tools/optimizer/graph/if_to_partial_pass.h diff --git a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc index fd51b902..05ac6816 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_ge/ge_graph_executor.cc @@ -1181,7 +1181,7 @@ backend::ge_backend::DfGraphPtr GeGraphExecutor::CreateGeGraphOnline( } MS_LOG(INFO) << "extra_variables_names size: " << extra_variables_names.size(); - auto converter = std::make_shared(anf_graph, "", ref_mode_flag_, + auto converter = std::make_shared(anf_graph, "", false, ref_mode_flag_, 
extra_variables_names, dyn_ref_data_func); backend::ge_backend::BuildGraph(graph_name_, converter, params_vals); auto err_code = backend::ge_backend::ErrCode(converter); diff --git a/mindspore-lite/tools/common/node_util.cc b/mindspore-lite/tools/common/node_util.cc index c8774ad2..b7e0af6a 100644 --- a/mindspore-lite/tools/common/node_util.cc +++ b/mindspore-lite/tools/common/node_util.cc @@ -31,6 +31,7 @@ #include "mindspore/ops/infer/switch.h" #include "mindspore/ops/infer/call.h" #include "mindspore/ops/infer/cxx_api/partial_fusion.h" +#include "mindspore/ops/infer/partial.h" #include "nnacl/op_base.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_a.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" @@ -393,6 +394,16 @@ ValueNodePtr GetPartialFusionPrim() { return partial_anf_prim; } +ValueNodePtr GetPartialPrim() { + auto partial_prim = std::make_shared(); + MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr"); + auto partial_prim_c = partial_prim->GetPrim(); + MS_CHECK_TRUE_MSG(partial_prim_c != nullptr, nullptr, "partial_prim_c is nullptr"); + ValueNodePtr partial_anf_prim = NewValueNode(partial_prim_c); + MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr"); + return partial_anf_prim; +} + ValueNodePtr GetSwitchAnfPrim() { auto switch_prim = std::make_shared(); MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr"); diff --git a/mindspore-lite/tools/common/node_util.h b/mindspore-lite/tools/common/node_util.h index f5dd0dcb..073eb134 100644 --- a/mindspore-lite/tools/common/node_util.h +++ b/mindspore-lite/tools/common/node_util.h @@ -423,6 +423,8 @@ bool IsMakeTuple(const AnfNodePtr &node); ValueNodePtr GetPartialFusionPrim(); +ValueNodePtr GetPartialPrim(); + ValueNodePtr GetSwitchAnfPrim(); ValueNodePtr GetCallAnfPrim(); diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc 
b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc index 1129ac5b..73f785f2 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc @@ -85,7 +85,7 @@ backend::ge_backend::DfGraphPtr ModelConverter::ConvertFuncGraphToAIR(const Func opt::ReduceOptimization(anf_graph); #endif auto converter = - backend::ge_backend::NewConverter(anf_graph, "", backend::ge_backend::RefModeFlag::kRefModeNone, true); + backend::ge_backend::NewConverter(anf_graph, "", backend::ge_backend::RefModeFlag::kRefModeNone, true, true); std::string compute_graph_name = anf_graph->ToString(); auto option = options_.lock(); if (option != nullptr && !option->GetDumpModelName().empty()) { diff --git a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc index f1752372..f3db4cc7 100644 --- a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc +++ b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc @@ -109,6 +109,9 @@ constexpr auto kConstFoldPass = "ConstFoldPass"; constexpr auto kAdjustCol2imPass = "AdjustCol2imPass"; constexpr auto kAdjustControlFlowPass = "AdjustControlflowPass"; constexpr auto kRemoveRedundantOpPass = "RemoveRedundantOpPass"; +constexpr auto kControlFlowPass = "ControlFlowPass"; +constexpr auto kIfToPartialPass = "IfToPartialPass"; +constexpr auto kCondTensorToScalarPass = "CondTensorToScalarPass"; constexpr auto kAdjustMatmulPass = "AdjustMatmulPass"; constexpr auto kAdjustAscendQunatPass = "AdjustAscendQunatPass"; constexpr auto kDelRedundantTranspose = "DeleteRedundantTranspose"; @@ -966,6 +969,10 @@ STATUS AclPassImpl::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_ MS_LOG(ERROR) << "convert args to attr pass failed"; return lite::RET_ERROR; } + if 
(!lite::RunOptimizerPass(func_graph, {kCondTensorToScalarPass, kIfToPartialPass})) { + MS_LOG(ERROR) << "kAdjustControlFlowPass failed!"; + return lite::RET_ERROR; + } if (!lite::RunOptimizerPass(func_graph, {kAdjustControlFlowPass})) { MS_LOG(ERROR) << "kAdjustControlFlowPass failed!"; return lite::RET_ERROR; diff --git a/mindspore-lite/tools/converter/anf_transform.cc b/mindspore-lite/tools/converter/anf_transform.cc index 15010562..ac9129d3 100644 --- a/mindspore-lite/tools/converter/anf_transform.cc +++ b/mindspore-lite/tools/converter/anf_transform.cc @@ -110,6 +110,8 @@ #include "tools/optimizer/fusion/reduce_same_op_in_horizon.h" #include "tools/optimizer/fusion/reshape_shape_fusion.h" #include "tools/optimizer/fusion/transpose_gather_fusion.h" +#include "tools/optimizer/graph/if_to_partial_pass.h" +#include "tools/optimizer/graph/cond_tensor_to_scalar.h" #ifndef ENABLE_CLOUD_FUSION_INFERENCE #include "tools/converter/adapter/acl/acl_pass.h" #endif @@ -850,7 +852,10 @@ bool AnfTransform::StoreBuiltinPass(const std::shared_ptr ¶m) {"AdjustCol2imPass", std::make_shared(), false}, {"AdjustAscendQunatPass", std::make_shared(), false}, {"AddStreamLabelPass", std::make_shared(param), false}, - {"AdjustControlflowPass", std::make_shared(), false}}; + {"AdjustControlflowPass", std::make_shared(), false}, + {"ControlFlowPass", std::make_shared(), false}, + {"IfToPartialPass", std::make_shared(), false}, + {"CondTensorToScalarPass", std::make_shared(), false}}; for (const auto &pass_info : pass_infos) { MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false); PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info), std::get(pass_info)); diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc index 3366376a..05edfd89 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ 
-14,6 +14,7 @@ * limitations under the License. */ #include "tools/converter/parser/onnx/onnx_model_parser.h" +#include #include #include #include @@ -537,37 +538,6 @@ FuncGraphPtr ConvertGraph(api::FuncGraphPtr func_graph) { return std::dynamic_pointer_cast(impl); } -STATUS RenameSubGraphInputName(const std::string &if_node_name, const std::vector &if_new_input_not_same, - const std::vector &then_subgraph_extra_inputs, - const std::vector &else_subgraph_extra_inputs) { - std::vector if_input_sub(if_new_input_not_same.begin() + kIndex4, if_new_input_not_same.end()); - std::map input_name_map; - for (size_t i = 0; i < if_input_sub.size(); i++) { - input_name_map[if_input_sub[i]->fullname_with_scope()] = i; - } - for (auto &sub_input : then_subgraph_extra_inputs) { - if (input_name_map.find(sub_input->fullname_with_scope()) == input_name_map.end()) { - MS_LOG(ERROR) << "Extra input name not in input name map, input name:" << sub_input->fullname_with_scope(); - return RET_ERROR; - } - auto input_parameter = sub_input->cast(); - MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter"); - input_parameter->set_name(if_node_name + "_then_branch" + "_input_" + - std::to_string(input_name_map[sub_input->fullname_with_scope()]) + "_parameter"); - } - for (auto &sub_input : else_subgraph_extra_inputs) { - if (input_name_map.find(sub_input->fullname_with_scope()) == input_name_map.end()) { - MS_LOG(ERROR) << "Extra input name not in input name map, input name:" << sub_input->fullname_with_scope(); - return RET_ERROR; - } - auto input_parameter = sub_input->cast(); - MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter"); - input_parameter->set_name(if_node_name + "_then_branch" + "_input_" + - std::to_string(input_name_map[sub_input->fullname_with_scope()]) + "_parameter"); - } - return RET_OK; -} - } // namespace FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, 
const onnx::GraphProto &subgraph_proto, @@ -580,13 +550,19 @@ FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, c auto act_outputs_num = node_outputs_num - (node_inputs_num - 2); auto loop_body_graph = std::make_shared(); MS_CHECK_TRUE_MSG(loop_body_graph != nullptr, nullptr, "create loop_body_graph return nullptr"); - std::unordered_map anf_nodes_map; std::vector gen_subgraph_inputs; - auto status = ConvertOnnxGraph(subgraph_proto, loop_body_graph, &anf_nodes_map, &gen_subgraph_inputs, loop_node_name); + auto loop_body_graph_tree_node = std::make_shared(loop_body_graph); + loop_body_graph_tree_node->anf_nodes_map_ = std::make_shared>(); + loop_body_graph_tree_node->parent = func_graph_tree_; + func_graph_tree_ = loop_body_graph_tree_node; + auto status = ConvertOnnxGraph(subgraph_proto, loop_body_graph, func_graph_tree_->anf_nodes_map_.get(), + &gen_subgraph_inputs, loop_node_name); if (status != RET_OK) { MS_LOG(ERROR) << "convert loop OnnxGraph: " << status; return nullptr; } + func_graph_tree_ = func_graph_tree_->parent; + func_graph_tree_->graph_control_flow_node_map.emplace(loop_body_graph_tree_node, loop_node_name); auto return_node = loop_body_graph->get_return(); MS_CHECK_TRUE_MSG(return_node != nullptr, nullptr, "return node of subgraph is nullptr"); MS_CHECK_TRUE_RET(return_node->size() == DIMENSION_2D, nullptr); @@ -597,8 +573,8 @@ FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, c gen_subgraph_inputs.end()); std::string max_trip_count_name = subgraph_proto.input(0).name(); - status = - AddIterNumsUpdateEdge(loop_body_graph, &return_new_inputs, anf_nodes_map, max_trip_count_name, loop_node_name); + status = AddIterNumsUpdateEdge(loop_body_graph, &return_new_inputs, *(loop_body_graph_tree_node->anf_nodes_map_), + max_trip_count_name, loop_node_name); if (status != RET_OK) { MS_LOG(ERROR) << "add iter nums update edge failed: " << status; return nullptr; @@ -608,7 +584,7 @@ FuncGraphPtr 
OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, c std::vector body_graph_inputs; body_graph_inputs.reserve(subgraph_proto.input_size()); for (int j = 0; j < subgraph_proto.input_size(); j++) { - body_graph_inputs.emplace_back(anf_nodes_map[subgraph_proto.input(j).name()]); + body_graph_inputs.emplace_back(loop_body_graph_tree_node->anf_nodes_map_->at(subgraph_proto.input(j).name())); } body_graph_inputs.insert(body_graph_inputs.end(), gen_subgraph_inputs.begin(), gen_subgraph_inputs.end()); if (act_outputs_num != 0) { @@ -631,7 +607,9 @@ FuncGraphPtr OnnxModelParser::BuildBodyGraph(const onnx::NodeProto &loop_node, c MS_CHECK_TRUE_RET(body_graph_inputs[j] != nullptr, nullptr); auto body_input = body_graph_inputs[j]->cast(); MS_CHECK_TRUE_RET(body_input != nullptr, nullptr); - body_input->set_name(body_graph_name + "_input_" + std::to_string(j) + "_parameter"); + auto param_name = body_input->name(); + func_graph_tree_->control_flow_node_inputs_map[root_while_node].emplace_back(param_name); + body_input->set_name(body_graph_name + "_" + param_name + "_input_" + std::to_string(j) + "_parameter"); } for (size_t j = 1; j < return_new_inputs.size(); j++) { if (utils::isa(return_new_inputs[j])) { @@ -685,8 +663,133 @@ STATUS CheckOnnxModel(const onnx::GraphProto &onnx_graph) { } } // namespace +STATUS RecursiveSetExtraInput(FuncGraphTreePtr child, std::string onnx_name, bool isMindIR) { + auto parent = child->parent; + MS_CHECK_TRUE_MSG(parent != nullptr, RET_ERROR, "parent is nullptr"); + auto control_flow_node_name = parent->graph_control_flow_node_map[child]; + auto control_node = parent->anf_nodes_map_->at(control_flow_node_name); + for (size_t i = 0; i < parent->control_flow_node_inputs_map[control_node].size(); i++) { + std::string name = parent->control_flow_node_inputs_map[control_node][i]; + if (name == onnx_name) { + auto input_node = parent->anf_nodes_map_->at(onnx_name); + MS_CHECK_TRUE_MSG(input_node != nullptr, RET_ERROR, "name is not in 
anf_nodes"); + std::string node_fullscope_with_name = GetValue(child->graph->get_attr("graph_name")) + "_" + + onnx_name + "_" + std::to_string(i) + "_parameter"; + child->anf_nodes_map_->at(onnx_name)->set_abstract(input_node->abstract()); + child->anf_nodes_map_->at(onnx_name)->cast()->set_name(node_fullscope_with_name); + return RET_OK; + } + } + if (parent->anf_nodes_map_->find(onnx_name) == parent->anf_nodes_map_->end()) { + auto ext_subgraph_input = parent->graph->add_parameter(); + ext_subgraph_input->set_name(onnx_name); + parent->anf_nodes_map_->emplace(onnx_name, ext_subgraph_input); + auto status = RecursiveSetExtraInput(parent, onnx_name, isMindIR); + MS_CHECK_TRUE_RET(status == RET_OK, RET_ERROR); + } else if (parent->anf_nodes_map_->at(onnx_name)->abstract() == nullptr) { + auto status = RecursiveSetExtraInput(parent, onnx_name, isMindIR); + MS_CHECK_TRUE_RET(status == RET_OK, RET_ERROR); + } + auto input_node = parent->anf_nodes_map_->at(onnx_name); + // auto is_has_default_parameter = input_node->isa() && input_node->cast()->has_default(); + std::string node_fullscope_with_name = ""; + // if (is_has_default_parameter) { + // auto parameter = input_node->cast(); + // auto tensor_info = parameter->default_param()->cast(); + // auto copy_tensor_info = + // CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(), tensor_info->data_type()); + // if (copy_tensor_info == nullptr) { + // MS_LOG(ERROR) << "memcpy failed."; + // return RET_ERROR; + // } + // child->anf_nodes_map_->at(onnx_name)->cast()->set_default_param(copy_tensor_info); + // node_fullscope_with_name = GetValue(child->graph->get_attr("graph_name")) + "_" + onnx_name; + // } else { + control_node->cast()->add_input(input_node); + parent->control_flow_node_inputs_map[control_node].emplace_back(onnx_name); + node_fullscope_with_name = GetValue(child->graph->get_attr("graph_name")) + "_" + onnx_name + "_" + + 
std::to_string(parent->control_flow_node_inputs_map[control_node].size() - 1) + + "_parameter"; + // In loop, cond_graph need to pass output of body_graph to body_graph + // so cond_graph need to add body_graph new input, and body_graph needs add input to its output, + if (opt::CheckPrimitiveType(control_node, prim::kPrimWhile)) { + auto control_cnode = control_node->cast(); + auto cond_graph = GetValueNode>(control_cnode->input(1)); + size_t name_index = cond_graph->get_inputs().size(); + auto input_parameter = cond_graph->add_parameter(); + MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "create input_parameter return nullptr"); + input_parameter->set_name(cond_graph->get_attr("graph_name")->ToString() + "_input_" + std::to_string(name_index) + + "_parameter"); + auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32); + input_parameter->set_abstract(input_abstract); + auto return_node = child->graph->get_return(); + auto return_tuple_cnode = return_node->input(1)->cast(); + auto return_new_inputs = return_tuple_cnode->inputs(); + return_new_inputs.emplace_back(child->anf_nodes_map_->at(onnx_name)); + return_tuple_cnode->set_inputs(return_new_inputs); + } + // } + child->anf_nodes_map_->at(onnx_name)->set_abstract(input_node->abstract()); + child->anf_nodes_map_->at(onnx_name)->cast()->set_name(node_fullscope_with_name); + return RET_OK; +} + +STATUS OnnxModelParser::HandleGraphsInputs() { + std::queue q; + std::vector> levelOrderTree; + std::vector root = {func_graph_tree_}; + bool isMindIR = save_type_ == kMindIR; + levelOrderTree.emplace_back(root); + q.emplace(func_graph_tree_); + while (!q.empty()) { + std::vector levelTreeNode = {}; + int size = q.size(); + for (auto i = 0; i < size; i++) { + auto &node = q.front(); + for (auto &pair : node->graph_control_flow_node_map) { + levelTreeNode.emplace_back(pair.first); + q.emplace(pair.first); + } + q.pop(); + } + if (!levelTreeNode.empty()) { + levelOrderTree.emplace_back(levelTreeNode); + } + 
} + std::reverse(levelOrderTree.begin(), levelOrderTree.end()); + for (auto &node : levelOrderTree) { + for (auto &child : node) { + for (auto parameter : child->graph->parameters()) { + MS_EXCEPTION_IF_NULL(parameter); + MS_CHECK_TRUE_RET(parameter->isa(), RET_ERROR); + auto param = parameter->cast(); + if (param->abstract() != nullptr) continue; + auto name = param->name(); + auto status = RecursiveSetExtraInput(child, name, isMindIR); + MS_CHECK_TRUE_RET(status == RET_OK, RET_ERROR); + } + } + } + return RET_OK; +} + +size_t getInputsIndex(AnfNodePtr node) { + auto node_name = node->fullname_with_scope(); + auto last_underline = node_name.find_last_of("_"); + node_name = node_name.substr(0, last_underline); + last_underline = node_name.find_last_of("_"); + try { + size_t index = static_cast(std::stoi(node_name.substr(last_underline + 1))); + return index; + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Get node(" << node_name << ") index failed: " << e.what(); + return RET_ERROR; + } +} + api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &flag) { auto model_file = flag.model_file; + save_type_ = flag.save_type; NotSupportOp::GetInstance()->set_fmk_type("ONNX"); auto graph = std::make_shared(); MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "create FuncGraph failed"); @@ -698,6 +801,7 @@ api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &f return nullptr; } MS_ASSERT(onnx_root_graph_ != nullptr); + func_graph_tree_ = std::make_shared(graph); status = ConvertOnnxGraph(onnx_root_graph_, graph, &anf_nodes_map_, {}, "root_node"); if (RET_OK != status) { @@ -705,12 +809,23 @@ api::FuncGraphPtr OnnxModelParser::Parse(const converter::ConverterParameters &f MS_LOG(ERROR) << "convert onnx graph failed."; return nullptr; } + func_graph_tree_->anf_nodes_map_ = std::make_shared>(anf_nodes_map_); + status = HandleGraphsInputs(); + MS_CHECK_TRUE_RET(status == RET_OK, nullptr); + static auto root_func_manager = 
Manage(graph); for (auto &subgraph : all_subgraphs_) { MS_CHECK_TRUE_RET(subgraph != nullptr, nullptr); subgraph->set_manager(root_func_manager); subgraph->set_attr("fmk", MakeValue(static_cast(converter::kFmkTypeOnnx))); + auto parameters = subgraph->parameters(); + std::stable_partition(parameters.begin(), parameters.end(), + [](AnfNodePtr node) { return !dyn_cast(node)->has_default(); }); + auto input_size = subgraph->get_inputs().size(); + std::sort(parameters.begin(), parameters.begin() + input_size, + [](const AnfNodePtr &a, const AnfNodePtr &b) { return getInputsIndex(a) < getInputsIndex(b); }); + subgraph->set_parameters(parameters); } graph->set_attr("graph_name", MakeValue("main_graph")); graph->set_attr("fmk", MakeValue(static_cast(converter::kFmkTypeOnnx))); @@ -943,27 +1058,16 @@ STATUS OnnxModelParser::ConvertNodes(const onnx::GraphProto &onnx_graph, const F STATUS OnnxModelParser::ConvertIfSubgraph(const onnx::GraphProto &subgraph_proto, const FuncGraphPtr &subgraph, const std::string &subgraph_name, const std::string &if_node_name, - const std::string &root_node_name, std::vector *subgraph_extra_inputs) { MS_CHECK_TRUE_RET(subgraph != nullptr, RET_NULL_PTR); - std::unordered_map anf_nodes_map; - auto status = ConvertOnnxGraph(subgraph_proto, subgraph, &anf_nodes_map, subgraph_extra_inputs, if_node_name); + func_graph_tree_->anf_nodes_map_ = std::make_shared>(); + auto status = ConvertOnnxGraph(subgraph_proto, subgraph, func_graph_tree_->anf_nodes_map_.get(), + subgraph_extra_inputs, if_node_name); if (status != RET_OK) { MS_LOG(ERROR) << "convert loop OnnxGraph failed"; return status; } subgraph->set_attr("graph_name", MakeValue(subgraph_name)); - // update subgraph in out name - for (int j = 0; j < subgraph_proto.input_size(); j++) { - auto input_anode_iter = anf_nodes_map.find(subgraph_proto.input(j).name()); - if (input_anode_iter == anf_nodes_map.end()) { - MS_LOG(ERROR) << "cannot find input anode"; - return RET_ERROR; - } - auto 
input_parameter = input_anode_iter->second->cast(); - MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_ERROR, "subgraph input should be a parameter"); - input_parameter->set_name(subgraph_name + "_input_" + std::to_string(j) + "_parameter"); - } auto return_node = subgraph->get_return(); MS_CHECK_TRUE_MSG(return_node != nullptr, RET_ERROR, "subgraph has no return"); MS_CHECK_GE(return_node->size(), kInputSize1, RET_ERROR); @@ -1007,8 +1111,15 @@ STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, subgraph_name = if_node_name + "_then_branch"; then_branch_graph = std::make_shared(); MS_CHECK_TRUE_MSG(then_branch_graph != nullptr, RET_NULL_PTR, "create then_branch_graph return nullptr"); - auto status = ConvertIfSubgraph(subgraph_proto, then_branch_graph, subgraph_name, if_node_name, root_node_name, - &then_subgraph_extra_inputs); + MS_LOG(WARNING) << "Start convert if then_branch subgraph " << subgraph_name << " finished"; + auto then_graph_tree_node = std::make_shared(then_branch_graph); + then_graph_tree_node->parent = func_graph_tree_; + func_graph_tree_ = then_graph_tree_node; + auto status = + ConvertIfSubgraph(subgraph_proto, then_branch_graph, subgraph_name, if_node_name, &then_subgraph_extra_inputs); + func_graph_tree_ = func_graph_tree_->parent; + func_graph_tree_->graph_control_flow_node_map.emplace(then_graph_tree_node, if_node_name); + MS_LOG(WARNING) << "Convert if then_branch subgraph " << subgraph_name << " finished"; if (status != RET_OK) { MS_LOG(ERROR) << "build if node else branch failed."; } @@ -1016,13 +1127,18 @@ STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, subgraph_name = if_node_name + "_else_branch"; else_branch_graph = std::make_shared(); MS_CHECK_TRUE_MSG(else_branch_graph != nullptr, RET_NULL_PTR, "create else_branch_graph return nullptr"); - auto status = ConvertIfSubgraph(subgraph_proto, else_branch_graph, subgraph_name, if_node_name, root_node_name, - &else_subgraph_extra_inputs); 
+ MS_LOG(WARNING) << "Start convert if else_branch subgraph " << subgraph_name << " finished"; + auto else_graph_tree_node = std::make_shared(else_branch_graph); + else_graph_tree_node->parent = func_graph_tree_; + func_graph_tree_ = else_graph_tree_node; + auto status = + ConvertIfSubgraph(subgraph_proto, else_branch_graph, subgraph_name, if_node_name, &else_subgraph_extra_inputs); + func_graph_tree_ = func_graph_tree_->parent; + func_graph_tree_->graph_control_flow_node_map.emplace(else_graph_tree_node, if_node_name); + MS_LOG(WARNING) << "Convert if else_branch subgraph " << subgraph_name << " finished"; if (status != RET_OK) { MS_LOG(ERROR) << "build if node else branch failed."; } - } else { - continue; } } all_subgraphs_.emplace_back(then_branch_graph); @@ -1037,19 +1153,14 @@ STATUS OnnxModelParser::ConvertIfOnnxNode(const onnx::NodeProto &onnx_node, if_new_inputs.insert(if_new_inputs.begin() + 1, {then_value_node, else_value_node}); std::vector if_new_input_not_same{}; - std::set if_set{}; + std::set if_set{}; for (auto &input : if_new_inputs) { - if (if_set.find(input) != if_set.end()) { + if (if_set.find(input->fullname_with_scope()) != if_set.end()) { continue; } + MS_LOG(WARNING) << "if_new_inputs name: " << input->fullname_with_scope(); if_new_input_not_same.push_back(input); - if_set.insert(input); - } - auto status = RenameSubGraphInputName(if_node_name, if_new_input_not_same, then_subgraph_extra_inputs, - else_subgraph_extra_inputs); - if (status != RET_OK) { - MS_LOG(ERROR) << "RenameSubGraphInputName failed!"; - return RET_ERROR; + if_set.insert(input->fullname_with_scope()); } root_if_node->set_inputs(if_new_input_not_same); return RET_OK; @@ -1075,73 +1186,11 @@ STATUS OnnxModelParser::BuildCNode(const onnx::NodeProto &onnx_node, const FuncG op_inputs.push_back(anf_nodes_map->at(input_name)); } else { // subgraph may refer root graph nodes - std::vector need_add_input_nodes; auto ext_subgraph_input = anf_graph->add_parameter(); 
MS_CHECK_TRUE_MSG(ext_subgraph_input != nullptr, RET_NULL_PTR, "create parameter return nullptr"); - ParameterPtr inner_extra_paramter = nullptr; - while (!loop_name.empty() && child_root_map_.find(loop_name) != child_root_map_.end()) { - auto cur_node_map = control_nodes_map_[loop_name]; - CHECK_NULL_RETURN(cur_node_map); - if (cur_node_map->find(input_name) != cur_node_map->end()) { - auto outside_input_node = cur_node_map->at(input_name); - CHECK_NULL_RETURN(outside_input_node); - // copy outside input parameter value to inside subgraph - ext_subgraph_input->set_abstract(outside_input_node->abstract()); - ext_subgraph_input->set_name(outside_input_node->fullname_with_scope()); - if (outside_input_node->isa()) { - auto parameter = outside_input_node->cast(); - if (!parameter->has_default()) { - MS_LOG(ERROR) << "outside_input_node should has data! node name:" << onnx_node.name() - << "outside input name:" << outside_input_node->fullname_with_scope() - << " input name:" << input_name; - return RET_ERROR; - } - auto tensor_info = parameter->default_param()->cast(); - CHECK_NULL_RETURN(tensor_info); - auto copy_tensor_info = CreateTensorInfo(tensor_info->data_c(), tensor_info->Size(), tensor_info->shape(), - tensor_info->data_type()); - if (copy_tensor_info == nullptr) { - MS_LOG(ERROR) << "memcpy failed."; - return RET_ERROR; - } - ext_subgraph_input->set_default_param(copy_tensor_info); - } else { - // output inside cnode need make extra input - CHECK_NULL_RETURN(graph_inputs); - graph_inputs->emplace_back(ext_subgraph_input); - if (cur_node_map->find(loop_name) != cur_node_map->end()) { - CHECK_NULL_RETURN(cur_node_map->at(loop_name)); - auto control_node = cur_node_map->at(loop_name)->cast(); - MS_CHECK_TRUE_RET(control_node != nullptr, RET_NULL_PTR); - control_node->add_input(outside_input_node); - } else { - MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map."; - return RET_ERROR; - } - for (auto &control_node : need_add_input_nodes) { 
- CHECK_NULL_RETURN(control_node); - auto func_graph = control_node->func_graph(); - auto extra_input_parameter = func_graph->add_parameter(); - MS_CHECK_TRUE_MSG(extra_input_parameter != nullptr, RET_NULL_PTR, "create parameter return nullptr"); - extra_input_parameter->set_name(input_name); - extra_input_parameter->set_abstract(outside_input_node->abstract()); - control_node->add_input(extra_input_parameter); - } - } - op_inputs.push_back(ext_subgraph_input); - anf_nodes_map->emplace(input_name, ext_subgraph_input); - break; - } else { - if (cur_node_map->find(loop_name) != cur_node_map->end()) { - CHECK_NULL_RETURN(cur_node_map->at(loop_name)); - need_add_input_nodes.emplace_back(cur_node_map->at(loop_name)->cast()); - } else { - MS_LOG(ERROR) << "loop node: " << loop_name << " not found in cur node map."; - return RET_ERROR; - } - loop_name = child_root_map_[loop_name]; - } - } + ext_subgraph_input->set_name(input_name); + op_inputs.push_back(ext_subgraph_input); + anf_nodes_map->emplace(input_name, ext_subgraph_input); } } auto new_cnode = anf_graph->NewCNode(primitive_c, op_inputs); diff --git a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.h index f8a52e99..90ce0dd8 100644 --- a/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore-lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,19 @@ namespace mindspore { namespace lite { +struct FuncGraphTree; +using FuncGraphTreePtr = std::shared_ptr; +struct FuncGraphTree { + FuncGraphPtr graph; + FuncGraphTreePtr parent; + // funcgraph -> control_flow_node name + std::unordered_map graph_control_flow_node_map; + // onnx node name -> funcgraph node + std::shared_ptr> anf_nodes_map_; + // control flow node -> has added onnx input name + std::unordered_map> control_flow_node_inputs_map; + explicit 
FuncGraphTree(FuncGraphPtr graph) : graph(std::move(graph)) {} +}; class OnnxModelParser : public converter::ModelParser { public: OnnxModelParser() = default; @@ -76,16 +90,19 @@ class OnnxModelParser : public converter::ModelParser { int *cond_graph_input_num); STATUS ConvertIfSubgraph(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &anf_graph, const std::string &subgrah_name, const std::string &if_node_name, - const std::string &root_node_name, std::vector *subgraph_extra_inputs); + std::vector *subgraph_extra_inputs); + STATUS HandleGraphsInputs(); onnx::ModelProto onnx_model_{}; onnx::GraphProto onnx_root_graph_{}; std::vector all_subgraphs_{}; + FuncGraphTreePtr func_graph_tree_; std::unordered_map anf_nodes_map_{}; std::unordered_map *> control_nodes_map_{}; std::unordered_map child_root_map_{}; // for nest control flow node std::string model_file_{}; bool has_subgraph_ = false; + ModelType save_type_ = kMindIR; }; } // namespace lite } // namespace mindspore diff --git a/mindspore-lite/tools/optimizer/fusion/adjust_controlflow_pass.cc b/mindspore-lite/tools/optimizer/fusion/adjust_controlflow_pass.cc index d3f2ff98..685f08a7 100644 --- a/mindspore-lite/tools/optimizer/fusion/adjust_controlflow_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/adjust_controlflow_pass.cc @@ -30,6 +30,7 @@ #include "mindspore/ops/op_def/lite_ops.h" #include "ops_utils/op_constants.h" #include "tools/converter/export_model.h" +#include "src/common/utils.h" namespace mindspore { namespace opt { @@ -56,8 +57,8 @@ int32_t AdjustControlflowPass::AdjustControlflow(const CNodePtr &cnode, const Fu MS_LOG(ERROR) << "cnode is nullptr!"; return lite::RET_ERROR; } - if (cnode->size() < ops::kSize3) { - MS_LOG(ERROR) << "If node size should larger than 3! current size:" << cnode->size(); + if (cnode->size() < ops::kSize4) { + MS_LOG(ERROR) << "If node size should larger than 4! 
current size:" << cnode->size(); return lite::RET_ERROR; } auto value_node = cnode->input(0)->cast(); @@ -125,32 +126,61 @@ int32_t AdjustControlflowPass::AdjustControlflow(const CNodePtr &cnode, const Fu return lite::RET_OK; } +void moveInputsToFront(const FuncGraphPtr &func_graph) { + auto parameters = func_graph->parameters(); + std::stable_partition(parameters.begin(), parameters.end(), + [](AnfNodePtr node) { return !dyn_cast(node)->has_default(); }); + func_graph->set_parameters(parameters); +} + +STATUS delRedundantParameter(const FuncGraphPtr &func_graph) { + CHECK_NULL_RETURN(func_graph); + auto nodes = TopoSort(func_graph->get_return()); + auto parameters = func_graph->parameters(); + for (auto ¶meter : parameters) { + CHECK_NULL_RETURN(parameter); + if (std::find(nodes.begin(), nodes.end(), parameter) == nodes.end() && parameter->isa()) { + func_graph->DropNode(parameter); + } + } + // for control-flow node, graph inputs must be in the front of parameter vector + moveInputsToFront(func_graph); + for (auto &node : nodes) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + for (size_t i = 0; i < cnode->size(); ++i) { + auto origin_input = cnode->input(i); + MS_CHECK_TRUE_RET(origin_input != nullptr, lite::RET_ERROR); + if (IsValueNode(origin_input)) { + auto sub_func_graph = GetValueNode(origin_input); + MS_CHECK_TRUE_RET(sub_func_graph != nullptr, lite::RET_ERROR); + auto delStatus = delRedundantParameter(sub_func_graph); + if (delStatus != lite::RET_OK) { + MS_LOG(INFO) << "STOP"; + } + MS_CHECK_TRUE_MSG(delStatus != RET_ERROR, lite::RET_ERROR, "delRedundantParameter sub_func_graph failed!"); + } + } + } + return lite::RET_OK; +} + bool AdjustControlflowPass::Run(const FuncGraphPtr &func_graph) { MS_CHECK_TRUE_RET(func_graph != nullptr, false); - auto node_list = TopoSort(func_graph->get_return()); auto manager = Manage(func_graph, true); if (manager == nullptr) { MS_LOG(ERROR) << "Manager is nullptr!"; return false; } - for (auto 
&node : node_list) { - if (!utils::isa(node)) { - continue; - } - if (!opt::CheckPrimitiveType(node, prim::kPrimIf)) { - continue; - } - MS_LOG(INFO) << "begin process if node"; - auto if_node = node->cast(); - MS_CHECK_TRUE_RET(if_node != nullptr, false); - if (AdjustControlflow(if_node, func_graph) != lite::RET_OK) { - MS_LOG(ERROR) << "This node run AdjustControlflow failed! Node_name is: " << if_node->fullname_with_scope() - << "!"; - return false; - } - MS_LOG(INFO) << "This node run AdjustControlflowPass success : " << if_node->fullname_with_scope(); - } - MS_LOG(INFO) << "AdjustControlflowPass end."; + auto new_param = std::make_shared(); + new_param->fmk_type = converter::kFmkTypeMs; + new_param->save_type = kMindIR; + + auto status = delRedundantParameter(func_graph); + MS_CHECK_TRUE_MSG(status == RET_OK, false, "delRedundantParameter main_graph failed!");  // Run() returns bool + UpdateManager(func_graph); return true; } } // namespace opt diff --git a/mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.cc b/mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.cc new file mode 100644 index 00000000..2c421d4f --- /dev/null +++ b/mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.cc @@ -0,0 +1,100 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#define USE_DEPRECATED_API +#include "tools/optimizer/graph/cond_tensor_to_scalar.h" +#include +#include +#include +#include +#include +#include "mindspore/ops/op_def/sequence_ops.h" +#include "mindspore/ops/op_def/lite_ops.h" +#include "mindspore/ops/op_def/framework_ops.h" +#include "include/errorcode.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "tools/common/node_util.h" +#include "nnacl/op_base.h" +#include "mindspore/core/include/abstract/abstract_function.h" +#include "tools/common/tensor_util.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_r.h" + +namespace mindspore::opt { + +size_t GetNodeIndex(const AnfNodePtr &input, const CNodePtr &user_node) { + MS_EXCEPTION_IF_NULL(input); + MS_EXCEPTION_IF_NULL(user_node); + + AnfNodePtrList input_list = user_node->inputs(); + auto pos = std::find(input_list.begin(), input_list.end(), input); + if (pos == input_list.end()) { + MS_LOG_WITH_NODE(EXCEPTION, user_node) + << input->fullname_with_scope() << " is not the input of " << user_node->fullname_with_scope(); + } + + // The first input is Primitive and needs to be skipped. 
+ return std::distance(input_list.begin() + 1, pos); +} + +bool CondTensorToScalarPass::Run(const FuncGraphPtr &fg) { + MS_ASSERT(fg != nullptr); + to_process_q.push_back(fg); + while (!to_process_q.empty()) { + auto cur_fg = to_process_q.front(); + to_process_q.pop_front(); + auto manager = cur_fg->manager(); + auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString(); + std::set control_flow_nodes; + auto node_list = TopoSort(cur_fg->get_return()); + for (auto &node : node_list) { + MS_ASSERT(node != nullptr); + if (utils::isa(node) && CheckPrimitiveType(node, prim::kPrimIf)) { + control_flow_nodes.insert(node); + } + } + if (control_flow_nodes.empty()) { + MS_LOG(INFO) << cur_fg_name << " not found control flow op, no need to process."; + continue; + } + for (const auto &if_node : control_flow_nodes) { + auto prim = NewValueNode(std::make_shared(prim::kPrimReshape->name())); + auto if_cnode = if_node->cast(); + auto cond_cnode = if_cnode->input(kIfCondIndex); + std::string cond_cnode_name = cond_cnode->fullname_with_scope(); + if (std::accumulate(cond_cnode->Shape()->GetShapeVector().begin(), cond_cnode->Shape()->GetShapeVector().end(), 1, + [](int a, int b) { return a * b; }) != 1) { + MS_LOG(ERROR) << cond_cnode_name << "'s shape is " << cond_cnode->Shape()->GetShapeVector() << " is not 1"; + return false; + } + auto reshape_node_name = cond_cnode->fullname_with_scope() + "_reshape"; + auto param_node = opt::BuildIntVecParameterNode(cur_fg, std::vector(), reshape_node_name); + AnfNodePtrList inputs = {prim, cond_cnode, param_node}; + CNodePtr reshape_to_tensor = cur_fg->NewCNode(inputs); + to_process_q.push_back(GetValueNode>(if_cnode->input(kIfThenIndex))); + to_process_q.push_back(GetValueNode>(if_cnode->input(kIfElseIndex))); + // set abstract + TypeId type_id; + (void)GetDataTypeFromAnfNode(cond_cnode, &type_id); + + auto tmp_abstract = std::make_shared(kValueAny, TypeIdToType(type_id)); + reshape_to_tensor->set_abstract(tmp_abstract); + 
manager->SetEdge(if_cnode, GetNodeIndex(cond_cnode, if_cnode) + 1, reshape_to_tensor); + } + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.h b/mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.h new file mode 100644 index 00000000..ed45f869 --- /dev/null +++ b/mindspore-lite/tools/optimizer/graph/cond_tensor_to_scalar.h @@ -0,0 +1,53 @@ +#ifndef COND_TENSOR_TO_SCALAR_H +#define COND_TENSOR_TO_SCALAR_H + +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include "tools/optimizer/common/gllo_utils.h" +#include "schema/inner/model_generated.h" +#include "include/backend/optimizer/pass.h" + +namespace mindspore::opt { +class CondTensorToScalarPass : public Pass { + public: + CondTensorToScalarPass() : Pass("cond_tensor_to_scalar_pass") {} + ~CondTensorToScalarPass() override = default; + bool Run(const FuncGraphPtr &fg) override; + + private: + const size_t kCNodePrimIndex = 0; + const size_t kCNodeFirstInputIndex = 1; + const size_t kCNodeSecondInputIndex = 2; + + const size_t kGetItemInputSize = 3; + const size_t kPartialFirstInputSize = 2; + + const size_t kIfMinInputSize = 4; + const size_t kIfThenIndex = 1; + const size_t kIfElseIndex = 2; + const size_t kIfCondIndex = 3; + + std::deque to_process_q{}; +}; +} // namespace mindspore::opt + +#endif // COND_TENSOR_TO_SCALAR_H diff --git a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc index 13138233..c93f8afb 100644 --- a/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/control_flow_pass.cc @@ -549,12 +549,12 @@ int ControlFlowPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, } else { for (auto &fg_input : origin_then_fg_inputs) { auto fg_input_name = fg_input->fullname_with_scope(); - auto pos = partial_fg_name.size() + sizeof("_input_"); - auto pos2 = fg_input_name.find('_', pos); - auto idx_str = fg_input_name.substr(pos - 1, pos2 - pos + 1); - auto partial_idx = 0; + auto last_underline = fg_input_name.find_last_of("_"); + fg_input_name = fg_input_name.substr(0, last_underline); + last_underline = fg_input_name.find_last_of("_"); + size_t partial_idx = 0; try { - partial_idx = std::stoi(idx_str); + partial_idx = static_cast(std::stoi(fg_input_name.substr(last_underline + 1))); } catch (const std::exception &e) { MS_LOG(ERROR) << "Get index failed: " << e.what(); 
return RET_FAILED; diff --git a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc index 5d984bf7..8ee4aa38 100644 --- a/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore-lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -612,6 +612,7 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); + size_t start_index = CheckPrimitiveType(cnode, prim::kPrimWhile) ? kInputSizeThree : kInputSizeFour; auto sub_inputs = sub_graph->get_inputs(); sub_inputs_map_[sub_graph] = sub_inputs; for (auto &node : sub_inputs) { @@ -623,7 +624,7 @@ int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGra last_underline = node_name.find_last_of("_"); size_t index = 0; try { - index = std::stoi(node_name.substr(last_underline + 1)) + static_cast(kInputSizeThree); + index = std::stoi(node_name.substr(last_underline + 1)) + static_cast(start_index); } catch (const std::exception &e) { MS_LOG(ERROR) << "Get index failed: " << e.what(); return lite::RET_ERROR; @@ -883,6 +884,7 @@ bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); node_infer_shape_.Init(fmk_type_, train_flag_); transpose_strategy_.Init(fmk_type_, train_flag_); + sub_inputs_map_ = {}; if (!delete_redundant_transpose_.Run(func_graph)) { MS_LOG(ERROR) << "Run delete-redundant-transpose pass failed."; return false; diff --git a/mindspore-lite/tools/optimizer/graph/if_to_partial_pass.cc b/mindspore-lite/tools/optimizer/graph/if_to_partial_pass.cc new file mode 100644 index 00000000..d04dc7c5 --- /dev/null +++ b/mindspore-lite/tools/optimizer/graph/if_to_partial_pass.cc @@ -0,0 +1,223 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define USE_DEPRECATED_API +#include "tools/optimizer/graph/if_to_partial_pass.h" +#include +#include +#include +#include +#include "mindspore/ops/op_def/sequence_ops.h" +#include "mindspore/ops/op_def/lite_ops.h" +#include "mindspore/ops/op_def/framework_ops.h" +#include "include/errorcode.h" +#include "tools/optimizer/common/gllo_utils.h" +#include "tools/common/node_util.h" +#include "nnacl/op_base.h" +#include "mindspore/core/include/abstract/abstract_function.h" +#include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_i.h" + +namespace mindspore::opt { + +int IfToPartialPass::ColloctIfNodes(const FuncGraphPtr &fg, std::set *control_flow_node) { + // notice: fg->nodes() is not work in this pass, cause too many useless parameter have been created. 
+ auto node_list = TopoSort(fg->get_return()); + for (auto &node : node_list) { + MS_ASSERT(node != nullptr); + if (utils::isa(node) && CheckPrimitiveType(node, prim::kPrimIf)) { + control_flow_node->insert(node); + } + } + + return RET_SUCCESS; +} + +int IfToPartialPass::CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg, + std::vector *then_partial_cnode_inputs) { + auto if_inputs = if_cnode->inputs(); + auto fg_name_attr = partial_fg->get_attr("graph_name"); + MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED); + auto partial_fg_name = fg_name_attr->ToString(); + std::vector if_external_inputs{}; + if_external_inputs.assign(if_inputs.begin() + kIfMinInputSize, if_inputs.end()); + auto origin_then_fg_inputs = partial_fg->get_inputs(); + if (if_external_inputs.size() < origin_then_fg_inputs.size()) { + MS_LOG(ERROR) << "graph is not right."; + return RET_FAILED; + } + // collect inputs for control flow node in the order of graph input + for (auto &fg_input : origin_then_fg_inputs) { + auto fg_input_name = fg_input->fullname_with_scope(); + std::string p = "_parameter"; + size_t end_pos = fg_input_name.size() - p.size(); + size_t start_pos = fg_input_name.rfind("_", end_pos - 1); + string idx_str = fg_input_name.substr(start_pos + 1, end_pos - start_pos - 1); + auto partial_idx = 0; + try { + partial_idx = std::stoi(idx_str); + } catch (const std::exception &e) { + MS_LOG(ERROR) << "Get index failed: " << e.what(); + return RET_FAILED; + } + then_partial_cnode_inputs->push_back(if_external_inputs.at(partial_idx)); + } + return RET_SUCCESS; +} + +int IfToPartialPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index, const CNodePtr &if_cnode, + CNodePtr *then_partial_cnode) { + auto then_vnode = if_cnode->input(index); + MS_ASSERT(then_vnode != nullptr); + auto then_fg = GetValueNode>(then_vnode); + MS_CHECK_TRUE_MSG(then_fg != nullptr, RET_FAILED, "Get value as func_graph failed."); + + // create then 
partial node + ValueNodePtr then_partial_anf_primitive = lite::GetPartialPrim(); + MS_CHECK_TRUE_MSG(then_partial_anf_primitive != nullptr, RET_FAILED, "GetPartialPrim failed."); + std::vector then_partial_cnode_inputs{then_partial_anf_primitive, then_vnode}; + if (CreateIfPartialNodeExternalInputs(if_cnode, then_fg, &then_partial_cnode_inputs) != RET_SUCCESS) { + MS_LOG(ERROR) << "CreateIfPartialNodeExternalInputs failed."; + return RET_FAILED; + } + // set fg inputs to then_partial_cnode inputs + auto origin_then_fg_inputs = then_fg->get_inputs(); + *then_partial_cnode = fg->NewCNode(then_partial_cnode_inputs); + MS_CHECK_TRUE_MSG(*then_partial_cnode != nullptr, RET_FAILED, "new cnode is nullptr"); + auto fg_name_attr = then_fg->get_attr("graph_name"); + MS_CHECK_TRUE_RET(fg_name_attr != nullptr, RET_FAILED); + auto then_fg_name = fg_name_attr->ToString(); + (*then_partial_cnode)->set_fullname_with_scope("partial_" + then_fg_name); + AbstractBasePtrList then_partial_args_list; + (void)std::for_each(then_partial_cnode_inputs.cbegin() + 2, then_partial_cnode_inputs.cend(), + [&then_partial_args_list](const AnfNodeWeakPtr &weak_node) { + auto node = weak_node.lock(); + MS_EXCEPTION_IF_NULL(node); + (void)then_partial_args_list.emplace_back(node->abstract()); + }); + auto then_partial_abs = std::make_shared( + then_fg->ToAbstract()->cast(), then_partial_args_list, *then_partial_cnode); + (*then_partial_cnode)->set_abstract(then_partial_abs); + to_process_q.push_back(then_fg); + + return RET_SUCCESS; +} + +int IfToPartialPass::CreateIfElsePartialNode(const FuncGraphPtr &main_fg, const CNodePtr &if_cnode, + CNodePtr *else_partial_cnode) { + return CreateIfPartialNode(main_fg, kIfElseIndex, if_cnode, else_partial_cnode); +} + +int IfToPartialPass::CreateIfThenPartialNode(const FuncGraphPtr &main_fg, const CNodePtr &if_cnode, + CNodePtr *then_partial_cnode) { + return CreateIfPartialNode(main_fg, kIfThenIndex, if_cnode, then_partial_cnode); +} + +int 
IfToPartialPass::ProcessIfOp(const FuncGraphPtr &fg, const std::set &if_nodes) { + if (if_nodes.empty()) { + MS_LOG(INFO) << "not found if, no need to process."; + return RET_SUCCESS; + } + for (const auto &if_node : if_nodes) { + auto if_cnode = if_node->cast(); + MS_ASSERT(if_cnode != nullptr); + if (if_cnode->size() < kIfMinInputSize) { + MS_LOG(ERROR) << "if input is not right."; + return RET_FAILED; + } + + CNodePtr then_partial_cnode = nullptr; + int ret = CreateIfThenPartialNode(fg, if_cnode, &then_partial_cnode); + if (ret != RET_SUCCESS) { + MS_LOG(ERROR) << "if create then partial cnode failed, ret: " << ret; + return ret; + } + CNodePtr else_partial_cnode = nullptr; + ret = CreateIfElsePartialNode(fg, if_cnode, &else_partial_cnode); + if (ret != RET_SUCCESS) { + MS_LOG(ERROR) << "if create else partial cnode failed, ret: " << ret; + return ret; + } + + // create switch cnode + ValueNodePtr switch_anf_primitive = lite::GetSwitchAnfPrim(); + if (switch_anf_primitive == nullptr) { + MS_LOG(ERROR) << "GetSwitchAnfPrim failed."; + return RET_FAILED; + } + + // insert switch node + std::vector switch_node_inputs = {switch_anf_primitive, if_cnode->input(kIfCondIndex), + then_partial_cnode, else_partial_cnode}; + auto switch_cnode = fg->NewCNode(switch_node_inputs); + MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_FAILED, "NewCNode failed"); + switch_cnode->set_fullname_with_scope("if-Switch-" + if_cnode->fullname_with_scope() + "-" + + fg->get_attr("graph_name")->ToString()); + switch_cnode->set_abstract(if_cnode->abstract()); + auto manager = Manage(fg, true); + if (manager == nullptr) { + MS_LOG(ERROR) << "Manager is nullptr!"; + return RET_FAILED;  // int status expected here; `false` would convert to 0 + } + manager->Replace(if_node, switch_cnode); + UpdateManager(fg); + fg->DropNode(if_cnode); + } + return RET_SUCCESS; +} + +int IfToPartialPass::ProcessControlOp(const FuncGraphPtr &fg) { + if (fg == nullptr) { + MS_LOG(ERROR) << "fg is nullptr."; + return RET_FAILED; + } + + std::set control_flow_nodes; + int ret =
ColloctIfNodes(fg, &control_flow_nodes); + if (ret != RET_SUCCESS) { + MS_LOG(ERROR) << "SplitGraph failed, ret: " << ret; + return ret; + } + + if (control_flow_nodes.empty()) { + MS_LOG(INFO) << "not found control flow op, no need to process."; + return RET_SUCCESS; + } + + ret = ProcessIfOp(fg, control_flow_nodes); + if (ret != RET_SUCCESS) { + MS_LOG(ERROR) << "ProcessIfOp failed."; + return ret; + } + return RET_SUCCESS; +} + +bool IfToPartialPass::Run(const FuncGraphPtr &fg) { + MS_ASSERT(fg != nullptr); + to_process_q.push_back(fg); + while (!to_process_q.empty()) { + auto cur_fg = to_process_q.front(); + auto cur_fg_name = cur_fg->get_attr("graph_name")->ToString(); + int ret = ProcessControlOp(cur_fg); + if (ret != RET_SUCCESS) { + MS_LOG(ERROR) << "ProcessControlOp for graph: " << cur_fg_name << " failed."; + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(ret); + return false; + } + to_process_q.pop_front(); + } + return true; +} +} // namespace mindspore::opt diff --git a/mindspore-lite/tools/optimizer/graph/if_to_partial_pass.h b/mindspore-lite/tools/optimizer/graph/if_to_partial_pass.h new file mode 100644 index 00000000..9ecb3de7 --- /dev/null +++ b/mindspore-lite/tools/optimizer/graph/if_to_partial_pass.h @@ -0,0 +1,69 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef IF_TO_PARTIAL_PASS_H +#define IF_TO_PARTIAL_PASS_H + +#include +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "include/backend/optimizer/pass.h" + +namespace mindspore::opt { +class IfToPartialPass : public Pass { + public: + IfToPartialPass() : Pass("if_to_partial_pass") {} + ~IfToPartialPass() override = default; + bool Run(const FuncGraphPtr &fg) override; + + private: + int ColloctIfNodes(const FuncGraphPtr &fg, std::set *control_flow_node); + + // process if + int CreateIfPartialNodeExternalInputs(const CNodePtr &if_cnode, const FuncGraphPtr &partial_fg, + std::vector *then_partial_cnode_inputs); + int CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &index, const CNodePtr &if_cnode, + CNodePtr *then_partial_cnode); + int CreateIfThenPartialNode(const FuncGraphPtr &main_fg, const CNodePtr &if_cnode, CNodePtr *then_partial_cnode); + int CreateIfElsePartialNode(const FuncGraphPtr &main_fg, const CNodePtr &if_cnode, CNodePtr *else_partial_cnode); + int ProcessIfOp(const FuncGraphPtr &fg, const std::set &if_nodes); + + int ProcessControlOp(const FuncGraphPtr &fg); + + const size_t kCNodePrimIndex = 0; + const size_t kCNodeFirstInputIndex = 1; + const size_t kCNodeSecondInputIndex = 2; + + const size_t kGetItemInputSize = 3; + const size_t kPartialFirstInputSize = 2; + + const size_t kWhileMinInputSize = 3; + const size_t kWhileCondIndex = 1; + const size_t kWhileBodyIndex = 2; + + const size_t kIfMinInputSize = 4; + const size_t kIfThenIndex = 1; + const size_t kIfElseIndex = 2; + const size_t kIfCondIndex = 3; + + std::deque to_process_q{}; +}; +} // namespace mindspore::opt + +#endif // IF_TO_PARTIAL_PASS_H diff --git a/mindspore-lite/tools/optimizer/graph/infershape_pass.cc b/mindspore-lite/tools/optimizer/graph/infershape_pass.cc index 81ed965a..5de8253c 100644 --- a/mindspore-lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/infershape_pass.cc @@ 
-329,6 +329,7 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); + size_t start_index = CheckPrimitiveType(cnode, prim::kPrimWhile) ? kInputSizeThree : kInputSizeFour; auto sub_inputs = sub_graph->get_inputs(); sub_inputs_map_[sub_graph] = sub_inputs; for (auto &node : sub_inputs) { @@ -340,7 +341,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt last_underline = node_name.find_last_of("_"); size_t index = 0; try { - index = static_cast(std::stoi(node_name.substr(last_underline + 1))) + kInputSizeThree; + index = static_cast(std::stoi(node_name.substr(last_underline + 1))) + start_index; } catch (const std::exception &e) { MS_LOG(ERROR) << "Get index failed: " << e.what(); return RET_ERROR; -- Gitee From f78f34af7d928a018660fc4f023fa56a38edb7f2 Mon Sep 17 00:00:00 2001 From: z00621985 Date: Fri, 18 Jul 2025 16:03:06 +0800 Subject: [PATCH 2/3] update submodule --- .gitmodules | 4 ++-- mindspore | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.gitmodules b/.gitmodules index d7f1a58b..4ce542a7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,5 +1,5 @@ [submodule "mindspore"] path = mindspore - url = https://gitee.com/mindspore/mindspore.git + url = https://gitee.com/zxx_xxz/mindspore.git # shallow = true - branch = r2.7.rc1 \ No newline at end of file + branch = zxx_control_flow_for_lite diff --git a/mindspore b/mindspore index 2365375a..08fec943 160000 --- a/mindspore +++ b/mindspore @@ -1 +1 @@ -Subproject commit 2365375a9d84065f7354601cb3b48ececb477bc8 +Subproject commit 08fec94398075439cf2761dfb9de685c034e8151 -- Gitee From fd7df700c71c4e8ed5bfa60f6d30cab5d8f29028 Mon Sep 17 00:00:00 2001 From: z00621985 Date: Fri, 18 Jul 2025 16:10:39 +0800 Subject: [PATCH 3/3] update test --- mindspore-lite/test/config_level0/models_ascend_cloud.cfg | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore-lite/test/config_level0/models_ascend_cloud.cfg b/mindspore-lite/test/config_level0/models_ascend_cloud.cfg index c3c553b6..cde0b250 100644 --- a/mindspore-lite/test/config_level0/models_ascend_cloud.cfg +++ b/mindspore-lite/test/config_level0/models_ascend_cloud.cfg @@ -103,7 +103,7 @@ cbg_ai_ocr_language_classify_latin.pb;1:data;2,48,1,50;;offline_resize 5 # cbg_ai_ocr_recognize_chinese_english.pb;1:input_0;1,2048,2048,1;;offline_resize 5 cbg_ai_ocr_recognize_chinese_english_vertical.pb;1:input_0;1,2048,2048,1;;offline_resize 5 cbg_ai_ocr_recognize_japanese_korean.pb;1:input_0;1,2048,2048,1;;offline_resize 5 -cbg_ai_text_filing_inpainting.pb;2:input_images,input_masks;1,32,32,3:1,32,32,1;;offline_resize 10 +cbg_ai_text_filing_inpainting.pb;2:input_images,input_masks;1,32,32,3:1,32,32,1;NHWC;offline_resize 10 # open_source_inception_resnet_v2.pb;1:input;2,299,299,3;;offline_resize 5 # open_source_mobilenet_v1_10_224_frozen.pb;1:input;2,224,224,3;;offline_resize 5 G7.pb;1:Placeholder;1,640,480,3;NHWC;offline_resize 5 -- Gitee