From cf8bb63aa8397400ec5d625b0e46d738afdfa7e6 Mon Sep 17 00:00:00 2001 From: guopeian Date: Thu, 16 Feb 2023 11:37:45 +0800 Subject: [PATCH] fix --- .../inc/framework/omg/parser/model_parser.h | 19 +++++++++++ tf_adapter/kernels/geop_npu.cc | 33 +++++++++---------- tf_adapter/kernels/geop_npu.h | 2 +- .../depends/ge_runner/src/ge_runner_stub.cc | 18 +++++++--- tf_adapter_2.x/npu_device/core/npu_device.cpp | 6 ++-- .../framework/omg/parser/model_parser.h | 2 +- .../tests/stub/include/stub/defines.h | 2 ++ tf_adapter_2.x/tests/stub/parser_stub.cpp | 16 +++++++-- 8 files changed, 70 insertions(+), 28 deletions(-) diff --git a/inc/graphengine/inc/framework/omg/parser/model_parser.h b/inc/graphengine/inc/framework/omg/parser/model_parser.h index 4902339d5..25e2597bb 100644 --- a/inc/graphengine/inc/framework/omg/parser/model_parser.h +++ b/inc/graphengine/inc/framework/omg/parser/model_parser.h @@ -41,6 +41,8 @@ using GetGraphCallback = std::function; +using GetGraphCallbackV3 = std::function &partitioned_serialized, + std::map &const_value_map)>; class GE_FUNC_VISIBILITY ModelParser { public: ModelParser() {} @@ -177,6 +179,23 @@ class GE_FUNC_VISIBILITY ModelParser { ge::ComputeGraphPtr &graph) { return UNSUPPORTED; } + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] partitioned_serialized partitioned serialized network model + * @param [in] const_value_map const value map, key: constant node name value: serialized constant output tensor + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const std::vector &partitioned_serialized, + const std::map &const_value_map, + GetGraphCallbackV3 callback, + ge::ComputeGraphPtr &graph) { + return UNSUPPORTED; + } }; } // namespace domi diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 43fb37a51..cab79c201 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -815,8 +815,11 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { std::map const_value_map; std::vector partition_graph; OP_REQUIRES_OK_ASYNC(ctx, SeparateGraphDef(ori_graph_def, partition_graph, const_value_map), done); - auto build_sub_graph = [this, flib_def](const std::string &graph) -> std::string { - return this->BuildSubGraph(flib_def, graph); + auto build_sub_graph = [this, flib_def](const std::string &graph, std::vector &partition_graph, + std::map &const_value_map) -> bool { + GraphDef sub_graph_def; + this->BuildSubGraphDef(flib_def, graph, sub_graph_def); + return SeparateGraphDef(sub_graph_def, partition_graph, const_value_map).ok(); }; ge::Status status = model_parser->ParseProtoWithSubgraph(partition_graph, const_value_map, build_sub_graph, compute_graph); @@ -1334,7 +1337,7 @@ Status GeOp::SeparateGraphDef(GraphDef &ori_graph_def, partition_graph.push_back(graph_def_str); return Status::OK(); } - LOG(INFO) << "GraphDef is beyond 2G, which is need separate"; + LOG(INFO) << "GraphDef is beyond 2G, which is need separate weight"; for (NodeDef &node : *ori_graph_def.mutable_node()) { if (node.op() == "Const") { std::string node_name = node.name(); @@ -1581,7 +1584,9 @@ int GeOp::RunTuning(std::vector &input_vec, std::vector &inp } auto build_sub_graph = [this, flib_def](const std::string &graph) -> std::string { - return this->BuildSubGraph(flib_def, graph); + GraphDef sub_graph_def; + this->BuildSubGraphDef(flib_def, graph, sub_graph_def); + return sub_graph_def.SerializeAsString(); }; ge::Status status = model_parser->ParseProtoWithSubgraph(ori_graph_def.SerializeAsString(), build_sub_graph, compute_graph); @@ -1668,19 +1673,19 @@ int GeOp::RunTuning(std::vector &input_vec, std::vector &inp return 0; } -std::string GeOp::BuildSubGraph(FunctionLibraryDefinition *flib_def, const std::string &graph) { +void GeOp::BuildSubGraphDef(FunctionLibraryDefinition *flib_def, const std::string &graph, GraphDef &sub_graph_def) { ADP_LOG(INFO) << "[GEOP] build_sub_graph enter, sub graph name is " << graph; const FunctionDef *func_def = flib_def->Find(graph); if (func_def == nullptr) { ADP_LOG(ERROR) << "[GEOP] Sub graph not found in library, sub graph name is " << graph; - return ""; + return; } // get infershape Graph subgraph(flib_def); Status status = InferShapeUtil::GetSubGraphFromFunctionDef(*flib_def, *func_def, &subgraph); if (status != Status::OK()) { ADP_LOG(ERROR) << "[GEOP] Get subgraph from functiondef fail:" << status.error_message(); - return ""; + return; } ADP_LOG(INFO) << "[GEOP] Get subgraph from functiondef success."; std::string enable_force_v2_control; @@ -1699,23 +1704,17 @@ std::string GeOp::BuildSubGraph(FunctionLibraryDefinition *flib_def, const std:: << " Generate desc failed in subgraph."; } } - unique_ptr sub_graph_def(new (std::nothrow) GraphDef()); - if (sub_graph_def == nullptr) { - ADP_LOG(ERROR) << "[GEOP] Malloc memory for subgraph def fail."; - return ""; - } - subgraph.ToGraphDef(sub_graph_def.get()); + subgraph.ToGraphDef(&sub_graph_def); if (enable_force_v2_control == "1") { - sub_graph_def->release_library(); - sub_graph_def->mutable_versions()->clear_min_consumer(); + sub_graph_def.release_library(); + sub_graph_def.mutable_versions()->clear_min_consumer(); } if (kDumpGraph) { const std::string pbtxt_path = GetDumpPath() + "TF_Subgraph_" + graph.c_str() + ".pbtxt"; - (void) WriteTextProto(Env::Default(), pbtxt_path, *sub_graph_def); + (void) WriteTextProto(Env::Default(), pbtxt_path, sub_graph_def); } ADP_LOG(INFO) << "[GEOP] build_sub_graph exit, sub graph name is " << graph; - return sub_graph_def->SerializeAsString(); } void GeOp::SetDynamicInput() { diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index e29982a01..b9591e894 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -115,7 +115,7 @@ private: int RunTuning(std::vector &input_vec, std::vector &inputs, const OpKernelContext *const ctx); - std::string BuildSubGraph(FunctionLibraryDefinition *flib_def, const std::string &graph); + void BuildSubGraphDef(FunctionLibraryDefinition *flib_def, const std::string &graph, GraphDef &sub_graph_def); void SetDynamicInput(); diff --git a/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc b/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc index c74aa7a82..ed30a74d1 100644 --- a/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc +++ b/tf_adapter/tests/depends/ge_runner/src/ge_runner_stub.cc @@ -114,7 +114,7 @@ class TensorFlowModelParser : public domi::ModelParser { Status ParseProtoWithSubgraph(const std::vector &partitioned_serialized, const std::map &const_value_map, - domi::GetGraphCallbackV2 callback, + domi::GetGraphCallbackV3 callback, ge::ComputeGraphPtr &graph) override; ge::DataType ConvertToGeDataType(const uint32_t type) override; @@ -143,14 +143,24 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &serializ ge::ComputeGraphPtr &graph) { std::vector partitioned_serialized{serialized_proto}; std::map const_value_map; - return ParseProtoWithSubgraph(partitioned_serialized, const_value_map, callback, graph); + auto callback_v3 = [callback] (const std::string &graph, std::vector &partition_graph, + std::map &const_value_map) -> bool { + (void)partition_graph; + (void)const_value_map; + std::string graph_def = callback(graph); + partition_graph.push_back(graph_def); + return true; + }; + return ParseProtoWithSubgraph(partitioned_serialized, const_value_map, callback_v3, graph); } Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::vector &partitioned_serialized, const std::map &const_value_map, - domi::GetGraphCallbackV2 callback, + domi::GetGraphCallbackV3 callback, ge::ComputeGraphPtr &graph) { - callback("finall_branch1_Y3CNZMF9Vv8"); + std::map subgraph_const_value; + std::vector subgraph_partitioned_serialized; + callback("finall_branch1_Y3CNZMF9Vv8", subgraph_partitioned_serialized, subgraph_const_value); return ge::SUCCESS; } diff --git a/tf_adapter_2.x/npu_device/core/npu_device.cpp b/tf_adapter_2.x/npu_device/core/npu_device.cpp index 424cdae7c..e7b4e853d 100644 --- a/tf_adapter_2.x/npu_device/core/npu_device.cpp +++ b/tf_adapter_2.x/npu_device/core/npu_device.cpp @@ -946,7 +946,8 @@ tensorflow::Status NpuDevice::TransTfGraph2GeGraph(TFE_Context *context, const s domi::ModelParserFactory::Instance()->CreateModelParser(domi::FrameworkType::TENSORFLOW); NPU_REQUIRES(parser != nullptr, tensorflow::errors::Internal("NPU Create new tensorflow model parser failed")); - auto request_subgraph = [name, context](const std::string &fn) -> std::string { + auto request_subgraph = [name, context](const std::string &fn, std::vector &partition_graph, + std::map &const_value_map) -> bool { DLOG() << "Tensorflow model parser requesting subgraph " << fn << " for ge graph " << name; tensorflow::FunctionLibraryDefinition *lib_def = npu::UnwrapCtx(context)->FuncLibDef(); const tensorflow::FunctionDef *fdef = lib_def->Find(fn); @@ -974,7 +975,8 @@ tensorflow::Status NpuDevice::TransTfGraph2GeGraph(TFE_Context *context, const s if (kDumpExecutionDetail || kDumpGraph) { WriteTextProto(tensorflow::Env::Default(), name + "_subgraph_" + fn + ".pbtxt", graph->ToGraphDefDebug()); } - return graph->ToGraphDefDebug().SerializeAsString(); + tensorflow::GraphDef sub_graph_def = graph->ToGraphDefDebug(); + return SeparateGraphDef(&sub_graph_def, partition_graph, const_value_map).ok(); }; std::map const_value_map; diff --git a/tf_adapter_2.x/tests/stub/include/framework/omg/parser/model_parser.h b/tf_adapter_2.x/tests/stub/include/framework/omg/parser/model_parser.h index c7d9b4d07..6793c7995 100644 --- a/tf_adapter_2.x/tests/stub/include/framework/omg/parser/model_parser.h +++ b/tf_adapter_2.x/tests/stub/include/framework/omg/parser/model_parser.h @@ -14,7 +14,7 @@ class GE_FUNC_VISIBILITY ModelParser { Status ParseProtoWithSubgraph(const std::vector &partitioned_serialized, const std::map &const_value_map, - GetGraphCallbackV2 callback, + GetGraphCallbackV3 callback, ge::ComputeGraphPtr &graph); }; } // namespace domi diff --git a/tf_adapter_2.x/tests/stub/include/stub/defines.h b/tf_adapter_2.x/tests/stub/include/stub/defines.h index 5be2c1243..e6228a447 100644 --- a/tf_adapter_2.x/tests/stub/include/stub/defines.h +++ b/tf_adapter_2.x/tests/stub/include/stub/defines.h @@ -309,6 +309,8 @@ const graphStatus GRAPH_NODE_NEED_REPASS = 50331647; namespace domi { using Status = uint32_t; using GetGraphCallbackV2 = std::function; +using GetGraphCallbackV3 = std::function &partitioned_serialized, + std::map &const_value_map)>; enum FrameworkType { CAFFE = 0, MINDSPORE = 1, diff --git a/tf_adapter_2.x/tests/stub/parser_stub.cpp b/tf_adapter_2.x/tests/stub/parser_stub.cpp index a1a081556..24c5cf482 100644 --- a/tf_adapter_2.x/tests/stub/parser_stub.cpp +++ b/tf_adapter_2.x/tests/stub/parser_stub.cpp @@ -82,12 +82,20 @@ Status ModelParser::ParseProtoWithSubgraph(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) { std::vector partitioned_serialized{serialized_proto}; std::map const_value_map; - return ParseProtoWithSubgraph(partitioned_serialized, const_value_map, callback, graph); + auto callback_v3 = [callback] (const std::string &graph, std::vector &partition_graph, + std::map &const_value_map) -> bool { + (void)partition_graph; + (void)const_value_map; + std::string graph_def = callback(graph); + partition_graph.push_back(graph_def); + return true; + }; + return ParseProtoWithSubgraph(partitioned_serialized, const_value_map, callback_v3, graph); } Status ModelParser::ParseProtoWithSubgraph(const std::vector &partitioned_serialized, const std::map &const_value_map, - GetGraphCallbackV2 callback, + GetGraphCallbackV3 callback, ge::ComputeGraphPtr &graph) { std::string graph_def_str = partitioned_serialized[0]; tensorflow::GraphDef graph_def; @@ -117,7 +125,9 @@ Status ModelParser::ParseProtoWithSubgraph(const std::vector &parti for (const auto &node : graph->graph->op_nodes()) { for (const auto &attr : node->attrs()) { if (attr.second.has_func()) { - callback(attr.second.func().name()); + std::map subgraph_const_value; + std::vector subgraph_partitioned_serialized; + callback(attr.second.func().name(), subgraph_partitioned_serialized, subgraph_const_value); } } } -- Gitee