From 87f84142b71054b17741de04cd83386883232847 Mon Sep 17 00:00:00 2001 From: l00278812 Date: Thu, 12 Jan 2023 14:36:08 +0800 Subject: [PATCH] fix bug for trans resource --- ...trans_resource_input_to_node_optimizer.cpp | 44 +++++++++++++-- tf_adapter_2.x/tests/st/adapter2_st.py | 53 +++++++++++++++++++ 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/tf_adapter_2.x/npu_device/core/optimizers/runtime/npu_trans_resource_input_to_node_optimizer.cpp b/tf_adapter_2.x/npu_device/core/optimizers/runtime/npu_trans_resource_input_to_node_optimizer.cpp index 1ababd8cb..64787cd1d 100644 --- a/tf_adapter_2.x/npu_device/core/optimizers/runtime/npu_trans_resource_input_to_node_optimizer.cpp +++ b/tf_adapter_2.x/npu_device/core/optimizers/runtime/npu_trans_resource_input_to_node_optimizer.cpp @@ -26,7 +26,7 @@ tensorflow::Status TransResourceInput2Node(TFE_Context *context, tensorflow::Gra bool is_while_body_graph = false); tensorflow::Status TransFunctionDef(TFE_Context *context, const std::string &func_name, - const std::string &new_func_name, + std::string &new_func_name, std::map &node_substitutes, bool is_while_body_graph = false) { npu::OptimizeStageGraphDumper dumper("Function." + func_name); @@ -60,13 +60,21 @@ tensorflow::Status TransFunctionDef(TFE_Context *context, const std::string &fun }; dumper.Dump("after_trans_resource", fbody->graph->ToGraphDefDebug()); + static int64_t unique_name_index = 0; + new_func_name = func_name + "_npu_" + std::to_string(unique_name_index++); NPU_REQUIRES_OK(tensorflow::GraphToFunctionDef(*fbody->graph, new_func_name, lookup, &optimized_fdef)); - NPU_REQUIRES_OK(lib_def->RemoveFunction(new_func_name)); NPU_REQUIRES_OK(lib_def->AddFunctionDef(optimized_fdef)); DLOG() << "Finish trans function " << func_name << " to " << new_func_name; return tensorflow::Status::OK(); } +void UpdateFuncName(tensorflow::Node *node, const std::string &attr_name, const std::string &new_name) { + tensorflow::AttrValue attr_value; + attr_value.mutable_func()->set_name(new_name); + node->ClearAttr(attr_name); + node->AddAttr(attr_name, attr_value); +} + tensorflow::Status TransWhileNode(TFE_Context *context, tensorflow::Graph *graph, tensorflow::Node *node) { DLOG() << "Start trans node " << node->name() << std::endl << node->DebugString(); std::map substitutes; @@ -87,9 +95,12 @@ tensorflow::Status TransWhileNode(TFE_Context *context, tensorflow::Graph *graph std::string body = node->attrs().Find("body")->func().name(); DLOG() << "Trans cond function " << cond << " of node " << node->name(); - (void)TransFunctionDef(context, cond, cond, substitutes); + std::string new_func_name; + (void)TransFunctionDef(context, cond, new_func_name, substitutes); + UpdateFuncName(node, "cond", new_func_name); DLOG() << "Trans body function " << body << " of node " << node->name(); - (void)TransFunctionDef(context, body, body, substitutes, true); + (void)TransFunctionDef(context, body, new_func_name, substitutes, true); + UpdateFuncName(node, "body", new_func_name); tensorflow::NodeDef ndef = node->def(); auto copied_type_attr = ndef.attr().at("T"); // Copy origin attr @@ -173,19 +184,42 @@ tensorflow::Status TransHasSubgraphNode(TFE_Context *context, tensorflow::Graph std::vector functions; if (node->IsIfNode()) { + DLOG() << "Start trans if node " << node->name() << std::endl; functions.emplace_back(node->attrs().Find("then_branch")->func().name()); functions.emplace_back(node->attrs().Find("else_branch")->func().name()); } else if (node->IsCaseNode()) { + DLOG() << "Start trans case node " << node->name() << std::endl; for (const auto &f : node->attrs().Find("branches")->list().func()) { functions.emplace_back(f.name()); } } else { + DLOG() << "Start trans f node " << node->name() << std::endl; functions.emplace_back(node->attrs().Find("f")->func().name()); } + int32_t case_index = 0; for (auto &fn : functions) { DLOG() << "Trans function " << fn << " of node " << node->name(); - (void)TransFunctionDef(context, fn, fn, substitutes); + std::string new_fn_name; + (void)TransFunctionDef(context, fn, new_fn_name, substitutes); + if (node->IsIfNode()) { + if (node->attrs().Find("then_branch")->func().name() == fn) { + UpdateFuncName(node, "then_branch", new_fn_name); + DLOG() << "After tran function " << node->name() << std::endl << node->attrs().Find("then_branch")->func().name(); + } else { + UpdateFuncName(node, "else_branch", new_fn_name); + DLOG() << "After tran function " << node->name() << std::endl << node->attrs().Find("else_branch")->func().name(); + } + } else if (node->IsCaseNode()) { + tensorflow::AttrValue attr_value; + attr_value.mutable_list()->mutable_func(case_index++)->set_name(new_fn_name); + node->ClearAttr("branches"); + node->AddAttr("branches", attr_value); + DLOG() << "After tran function " << node->name() << std::endl << node->attrs().Find("branches")->func().name(); + } else { + UpdateFuncName(node, "f", new_fn_name); + DLOG() << "After tran function " << node->name() << std::endl << node->attrs().Find("f")->func().name(); + } } tensorflow::NodeDef ndef = node->def(); diff --git a/tf_adapter_2.x/tests/st/adapter2_st.py b/tf_adapter_2.x/tests/st/adapter2_st.py index dda926e5b..fa92efb4d 100644 --- a/tf_adapter_2.x/tests/st/adapter2_st.py +++ b/tf_adapter_2.x/tests/st/adapter2_st.py @@ -348,6 +348,59 @@ class Adapter2St(unittest.TestCase): f(iterator) time.sleep(5) + def test_switch_case1(self): + @tf.function + def my_add(x, y): + return tf.add(x, y) + + @tf.function + def f0(): + return tf.constant(0) + + @tf.function + def f1(): + return tf.constant(1) + + @tf.function + def f2(): + return tf.constant(2) + + @tf.function + def f3(index): + return index + + x = tf.constant(1) + @tf.function + def f(index): + return my_add(tf.switch_case(f3(index), {0 : f1, 1 : f2}, default = f2), x) + + v0 = f(tf.constant(0)) + v1 = f(tf.constant(1)) + v2 = f(tf.constant(99)) + self.assertTrue(tensor_equal(v0, tf.constant(1))) + self.assertTrue(tensor_equal(v1, tf.constant(2))) + self.assertTrue(tensor_equal(v2, tf.constant(3))) + + def test_case1(self): + @tf.function + def my_add(x, y): + return tf.add(x, y) + + @tf.function + def f1(): + return tf.constant(1) + + @tf.function + def f2(): + return tf.constant(2) + + @tf.function + def f(): + return my_add(tf.case([(tf.less(0, 1), f1)], default=f2), tf.case([(tf.less(1, 0), f1)], default=f2)) + + v0 = f() + self.assertTrue(tensor_equal(v0, tf.constant(3))) + def test_dropout_v3(self): @tf.function def f(x): -- Gitee