diff --git a/mindspore-src/source/mindspore/lite/src/litert/c_api/model_c.cc b/mindspore-src/source/mindspore/lite/src/litert/c_api/model_c.cc
index 8e9863a846881510cd3a53be8ca8a7b256460145..3922210ae3397244bc700cc124d78ca07444a6da 100644
--- a/mindspore-src/source/mindspore/lite/src/litert/c_api/model_c.cc
+++ b/mindspore-src/source/mindspore/lite/src/litert/c_api/model_c.cc
@@ -59,6 +59,7 @@ class ModelC {
   mindspore::MSKernelCallBack TransCallBack(const OH_AI_KernelCallBack &ms_callback);
   std::shared_ptr<mindspore::Model> model_;
   std::shared_ptr<mindspore::Context> context_;
+  std::atomic<bool> is_model_in_use_;
 
  private:
   MSTensor **GetOutputsTensor(size_t *output_num, std::vector<mindspore::MSTensor> *vec_tensors);
@@ -171,6 +172,12 @@ void OH_AI_ModelDestroy(OH_AI_ModelHandle *model) {
     return;
   }
   auto impl = static_cast<ModelC *>(*model);
+  auto locked = impl->is_model_in_use_.load();
+  if (locked) {
+    MS_LOG(ERROR) << "Fail to destroy model, model in use";
+    return;
+  }
+  impl->is_model_in_use_.store(true);
   delete impl;
   *model = nullptr;
   MS_LOG(INFO) << "Destroyed ms model successfully";
@@ -229,8 +236,15 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model
   }
   mindspore::ContextC *context = static_cast<mindspore::ContextC *>(model_context);
   auto impl = static_cast<ModelC *>(model);
+  auto locked = impl->is_model_in_use_.load();
+  if (locked) {
+    MS_LOG(ERROR) << "Fail to build from file, model in use";
+    return OH_AI_STATUS_LITE_ERROR;
+  }
+  impl->is_model_in_use_.store(true);
   if (impl->context_.get() != context->context_ && context->owned_by_model_) {
     MS_LOG(ERROR) << "context is owned by other model.";
+    impl->is_model_in_use_.store(false);
     return OH_AI_STATUS_LITE_PARAM_INVALID;
   }
   if (impl->context_.get() != context->context_) {
@@ -243,6 +257,7 @@ OH_AI_Status OH_AI_ModelBuildFromFile(OH_AI_ModelHandle model, const char *model
   } else {
     MS_LOG(ERROR) << "Built ms model from file failed, ret: " << ret;
   }
+  impl->is_model_in_use_.store(false);
   return static_cast<OH_AI_Status>(ret.StatusCode());
 }
 
@@ -285,10 +300,17 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
     return OH_AI_STATUS_LITE_NULLPTR;
   }
   auto impl = static_cast<ModelC *>(model);
+  auto locked = impl->is_model_in_use_.load();
+  if (locked) {
+    MS_LOG(ERROR) << "Fail to predict, model in use";
+    return OH_AI_STATUS_LITE_ERROR;
+  }
+  impl->is_model_in_use_.store(true);
   size_t input_num;
   (void)impl->GetInputs(&input_num);
   if (input_num != inputs.handle_num) {
     MS_LOG(ERROR) << "Wrong input size.";
+    impl->is_model_in_use_.store(false);
     return OH_AI_STATUS_LITE_ERROR;
   }
 
@@ -299,6 +321,7 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
       ms_tensor_inputs.push_back(*user_input);
     } else {
       MS_LOG(ERROR) << "input handle is nullptr.";
+      impl->is_model_in_use_.store(false);
       return OH_AI_STATUS_LITE_NULLPTR;
     }
   }
@@ -315,6 +338,7 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
   for (size_t i = 0; i < output_num; i++) {
     if (outputs->handle_list[i] == nullptr) {
       MS_LOG(ERROR) << "user provided output array handle_list[" << i << "] is nullptr";
+      impl->is_model_in_use_.store(false);
       return OH_AI_STATUS_LITE_NULLPTR;
     }
     ms_tensor_outputs.push_back(*static_cast<mindspore::MSTensor *>(outputs->handle_list[i]));
@@ -324,15 +348,18 @@ OH_AI_Status OH_AI_ModelPredict(OH_AI_ModelHandle model, const OH_AI_TensorHandl
   auto ret = impl->model_->Predict(ms_tensor_inputs, &ms_tensor_outputs, before_call_back, after_call_back);
   if (!ret.IsOk()) {
     MS_LOG(ERROR) << "Predict fail, ret :" << ret;
+    impl->is_model_in_use_.store(false);
     return static_cast<OH_AI_Status>(ret.StatusCode());
   }
   if (handle_num == output_num) {
+    impl->is_model_in_use_.store(false);
     return OH_AI_STATUS_SUCCESS;
   }
   outputs->handle_list = reinterpret_cast<OH_AI_TensorHandle *>(impl->GetOutputs(&(outputs->handle_num)));
   MS_LOG(INFO) << "Predicted ms model successfully";
+  impl->is_model_in_use_.store(false);
   return static_cast<OH_AI_Status>(ret.StatusCode());
 }
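
A note on the locking scheme: each guarded entry point checks the flag with `load()` and then sets it with a separate `store(true)`. `std::atomic` makes each of those operations atomic individually, not the check-and-set pair, so two threads can both observe `false` and enter together. Below is a minimal sketch of the same busy-flag guard with the test-and-set collapsed into a single `compare_exchange_strong`, wrapped in an RAII object so the flag is released on every return path. The names (`BusyGuard`, `g_model_in_use`, `Predict`) are illustrative only, not part of the MindSpore C API.

```cpp
#include <atomic>
#include <cstdio>

// RAII busy-flag guard: acquires the flag atomically in the constructor and
// releases it in the destructor, so an early return cannot leak a set flag.
class BusyGuard {
 public:
  explicit BusyGuard(std::atomic<bool> *flag) : flag_(flag) {
    bool expected = false;
    // Atomic test-and-set: succeeds only if the flag was false, and no other
    // thread can acquire the flag between the test and the set.
    acquired_ = flag_->compare_exchange_strong(expected, true);
  }
  ~BusyGuard() {
    if (acquired_) {
      flag_->store(false);  // Release on every exit path.
    }
  }
  BusyGuard(const BusyGuard &) = delete;
  BusyGuard &operator=(const BusyGuard &) = delete;
  bool acquired() const { return acquired_; }

 private:
  std::atomic<bool> *flag_;
  bool acquired_ = false;
};

std::atomic<bool> g_model_in_use{false};  // Stands in for impl->is_model_in_use_.

int Predict() {
  BusyGuard guard(&g_model_in_use);
  if (!guard.acquired()) {
    std::fprintf(stderr, "model in use\n");  // Mirrors the MS_LOG(ERROR) branches.
    return -1;                               // Mirrors OH_AI_STATUS_LITE_ERROR.
  }
  // ... run inference; the flag is cleared automatically when `guard` goes
  // out of scope, covering both success and every early-return branch.
  return 0;
}

int main() { return Predict(); }
```

With a scope guard like this, the repeated `impl->is_model_in_use_.store(false)` calls before each early `return` in the patch would become unnecessary, since the destructor clears the flag on every exit path.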