diff --git a/tf_adapter/python/npu_bridge/npu_init.py b/tf_adapter/python/npu_bridge/npu_init.py index 51a6eed63f6bfebe4e2f8f928638198cadb00b29..b4bfb113f29d0545a88b9b58fe37b318456c0b5c 100644 --- a/tf_adapter/python/npu_bridge/npu_init.py +++ b/tf_adapter/python/npu_bridge/npu_init.py @@ -70,6 +70,7 @@ from hccl.split.api import set_split_strategy_by_size from npu_bridge import embedding as npu_embedding from npu_bridge.profiler import profiler from npu_bridge.npu_cpu import npu_cpu_ops +from npu_bridge.scoped_graph_manager import scoped_graph_manager import tensorflow as tf diff --git a/tf_adapter/tests/st/kernels/testcase/geop_npu_test.cc b/tf_adapter/tests/st/kernels/testcase/geop_npu_test.cc index 330e3a61fba1267687ded6e7d7ef5c9bd50c94c4..216e7bedd82052533cdf0ebd4f31dd0edb2618bd 100644 --- a/tf_adapter/tests/st/kernels/testcase/geop_npu_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/geop_npu_test.cc @@ -514,6 +514,17 @@ TEST_F(GeOpTest, GeOpFuncTestWithLifeCycleControl) { EXPECT_TRUE(ScopedGraphManager::Instance().IsControlEnabled() == false); } +TEST_F(GeOpTest, GeOpFuncTestEmptyGraphWithLifeCycleControl) { + NpuClose(); + NodeDef node_def; + std::string graph_def_path = "tf_adapter/tests/ut/kernels/pbtxt/geop.pbtxt"; + gtl::InlinedVector inputs; + EnableControl(); + EXPECT_TRUE(ScopedGraphManager::Instance().IsControlEnabled() == true); + Clear(); + EXPECT_TRUE(ScopedGraphManager::Instance().IsControlEnabled() == false); +} + TEST_F(GeOpTest, GeOpFuncTestWithProfilingDefaultAicMetircs) { NpuClose(); NodeDef node_def; diff --git a/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc b/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc index 425d3685e1290d1711d48ba23ef7358c9279a86e..fac7ac9c5bc436aa2efcf413ebc51b2fd8e2d6d6 100644 --- a/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/geop_npu_test.cc @@ -261,6 +261,17 @@ TEST_F(GeOpTest, GeOpFuncTestWithLifeCycleControl) { EXPECT_TRUE(ScopedGraphManager::Instance().IsControlEnabled() == false); } +TEST_F(GeOpTest, GeOpFuncTestEmptyGraphWithLifeCycleControl) { + NpuClose(); + NodeDef node_def; + std::string graph_def_path = "tf_adapter/tests/ut/kernels/pbtxt/geop.pbtxt"; + gtl::InlinedVector inputs; + EnableControl(); + EXPECT_TRUE(ScopedGraphManager::Instance().IsControlEnabled() == true); + Clear(); + EXPECT_TRUE(ScopedGraphManager::Instance().IsControlEnabled() == false); +} + TEST_F(GeOpTest, GeOpFuncTestWithProfilingDefaultAicMetircs) { NpuClose(); NodeDef node_def; diff --git a/tf_adapter/util/scoped_graph_manager.cc b/tf_adapter/util/scoped_graph_manager.cc index 069c383a0aedefea83e279534f3543640b306b3f..13e220086949e473225d065d143b9cd9e52635f1 100644 --- a/tf_adapter/util/scoped_graph_manager.cc +++ b/tf_adapter/util/scoped_graph_manager.cc @@ -39,8 +39,6 @@ void ScopedGraphManager::EnableControl() { void ScopedGraphManager::DisableControl() { std::lock_guard lock(mutex_); graph_life_control_enabled_ = false; - graph_id_ = UINT32_MAX; - tf_session_.clear(); ADP_LOG(INFO) << "[ScopedGraphManager] Set graph_life_control_enabled_ false"; } @@ -68,20 +66,28 @@ bool ScopedGraphManager::SetGraph(const std::string& tf_session, const uint32_t& void ScopedGraphManager::Clear() { ADP_LOG(INFO) << "[ScopedGraphManager] Begin to clear after graph run"; - { - std::lock_guard lock(mutex_); - ge::Session* global_ge_session = nullptr; - std::map global_sess_options; - if (!SessionManager::GetInstance().GetOrCreateGeSession(tf_session_, global_ge_session, global_sess_options)) { - ADP_LOG(WARNING) << "[ScopedGraphManager] Failed to get session for tf_session: " << tf_session_; - } - if (global_ge_session != nullptr) { - global_ge_session->RemoveGraph(graph_id_); - ADP_LOG(INFO) << "[ScopedGraphManager] RemoveGraph success for tf_session: " - << tf_session_ << ", graph_id: " << graph_id_; - } - } DisableControl(); + + std::lock_guard lock(mutex_); + ge::Session* global_ge_session = nullptr; + std::map global_sess_options; + // 空图,未注册图,tf_session_为空 + if (tf_session_.empty()) { + ADP_LOG(WARNING) << "[ScopedGraphManager] No need to RemoveGraph, tf_session is empty"; + graph_id_ = UINT32_MAX; + ADP_LOG(INFO) << "[ScopedGraphManager] Clear finished"; + return; + } + if (!SessionManager::GetInstance().GetOrCreateGeSession(tf_session_, global_ge_session, global_sess_options)) { + ADP_LOG(WARNING) << "[ScopedGraphManager] Failed to get session for tf_session: " << tf_session_; + } + if (global_ge_session != nullptr) { + global_ge_session->RemoveGraph(graph_id_); + ADP_LOG(INFO) << "[ScopedGraphManager] RemoveGraph success for tf_session: " + << tf_session_ << ", graph_id: " << graph_id_; + } + graph_id_ = UINT32_MAX; + tf_session_.clear(); ADP_LOG(INFO) << "[ScopedGraphManager] Clear finished"; } } \ No newline at end of file