diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 3f086db4f39a96acc39064fa7f12f26fabcf35f4..1f637e69db60ce9db467eb63f7dd4a0658f2a227 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -1050,6 +1050,7 @@ Status GeOp::DoGraphParser(ge::ComputeGraphPtr &compute_graph, FunctionLibraryDe } void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { + run_mtx_.lock(); // ctx is not nullptr OP_REQUIRES_ASYNC(ctx, init_flag_, errors::InvalidArgument("GeOp not Initialize success."), done); if (!sess_init_flag_) { @@ -1318,9 +1319,10 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { } int64 run_start_time = InferShapeUtil::GetCurrentTimestap(); - auto callback = [done, ctx, run_start_time](ge::Status ge_status, std::vector &outputs) { + auto callback = [done, ctx, run_start_time, this](ge::Status ge_status, std::vector &outputs) { if (ge_status == ge::SUCCESS) { if (BuildOutputTensorInfo(ctx, outputs) != Status::OK()) { + run_mtx_.unlock(); ADP_LOG(FATAL) << ctx->op_kernel().name() << " GEOP::DoRunAsync get output failed."; std::string error_message = ge::GEGetErrorMsg(); std::stringstream ss; @@ -1335,6 +1337,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(WARNING) << "[GEOP] Out of range: End of sequence."; LOG(WARNING) << "[GEOP] Out of range: End of sequence."; } else if (ge_status != ge::SUCCESS) { + run_mtx_.unlock(); std::this_thread::sleep_for(std::chrono::milliseconds(kFatalSleepTime)); ADP_LOG(FATAL) << ctx->op_kernel().name() << "GEOP::::DoRunAsync Failed"; std::string error_message = ge::GEGetErrorMsg(); @@ -1348,6 +1351,7 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { ADP_LOG(INFO) << "[GEOP] RunGraphAsync callback, status:" << ge_status << ", kernel_name:" << ctx->op_kernel().name() << "[ " << (run_end_time - run_start_time) << "us]"; done(); + run_mtx_.unlock(); }; // call ge session runGraphAsync api diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index bf102b2ff9990790a9f848ffedcf54a7d5ae368e..01ca513c062946336881511b4c4ab3d57113f42d 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -247,6 +247,7 @@ public: AoeSetTuningGraphInputFunc aoe_set_tuning_graph_input_; // accelerate train AccelerateInfo accelerate_info_; + std::mutex run_mtx_; }; } // namespace tensorflow #endif // TENSORFLOW_KERNELS_GEOP_NPU_H_