From c25b76949e5f44a3c9e6feb00b1da61f6591cc36 Mon Sep 17 00:00:00 2001
From: yiguangzheng
Date: Sat, 20 Sep 2025 14:09:27 +0800
Subject: [PATCH] feat: ascend_context timeout option

---
 mindspore-lite/src/common/common.h            |   3 +
 .../delegate/ascend_acl/acl_graph_executor.cc |  23 +++-
 .../delegate/ascend_acl/acl_graph_executor.h  |   1 +
 .../delegate/ascend_acl/acl_model_options.h   |  18 +++
 .../delegate/ascend_acl/model_process.cc      |  65 +++++++++-
 .../delegate/ascend_acl/model_process.h       |   3 +
 .../test/st/python/python_api/conftest.py     | 100 +++++++++++++++
 .../test/st/python/python_api/pytest.ini      |   3 +
 .../python_api/test_stream_sync_timeout.py    | 120 ++++++++++++++++++
 .../test/st/python/python_api/utils.py        |  42 ++++++
 .../st/scripts/ascend/run_cloud_arm_a2.sh     |   2 +
 11 files changed, 372 insertions(+), 8 deletions(-)
 create mode 100644 mindspore-lite/test/st/python/python_api/conftest.py
 create mode 100644 mindspore-lite/test/st/python/python_api/pytest.ini
 create mode 100644 mindspore-lite/test/st/python/python_api/test_stream_sync_timeout.py
 create mode 100644 mindspore-lite/test/st/python/python_api/utils.py

diff --git a/mindspore-lite/src/common/common.h b/mindspore-lite/src/common/common.h
index cd229acd..eeac523d 100644
--- a/mindspore-lite/src/common/common.h
+++ b/mindspore-lite/src/common/common.h
@@ -119,6 +119,9 @@ static const char *const kVariableWeightsFile = "variable_weights_file";
 static const char *const kMaxWeightBatch = "max_weight_batch";
 static const char *const kStreamLabelFile = "stream_label_file";
 static const char *const kSplitNodeName = "split_node_name";
+static const char *const kTimeout = "timeout";
+static constexpr int32_t kModelExecStreamSyncTimeoutIgnoreValue = 0;
+static constexpr int32_t kModelExecStreamSyncTimeoutUnlimitedValue = -1;
 // ge options
 static const char *const kGeSessionOptionsSection = "ge_session_options";
 static const char *const kGeGraphOptionsSection = "ge_graph_options";
diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc
index 923dbf59..cc203208 100644
--- a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc
+++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.cc
@@ -124,6 +124,26 @@ void AclGraphExecutor::GetShareMemInfos(std::shared_ptr<AclModelOptions> acl_opt
   }
 }
 
+Status AclGraphExecutor::GetExecConfig(const std::shared_ptr<AclModelOptions> &acl_options_ptr) {
+  MS_CHECK_TRUE_MSG(acl_options_ptr != nullptr, kLiteError, "Acl options ptr is nullptr.");
+  std::string stream_sync_timeout_str = GetConfigOption(lite::kAscendContextSection, lite::kTimeout);
+  if (stream_sync_timeout_str.empty()) {
+    return kSuccess;
+  }
+  int32_t stream_sync_timeout = INT32_MIN;
+  if (!lite::ConvertStrToInt(stream_sync_timeout_str, &stream_sync_timeout)) {
+    MS_LOG(ERROR) << "Convert stream_sync_timeout_str to int failed, got: " << stream_sync_timeout_str;
+    return kLiteInputParamInvalid;
+  }
+  if (stream_sync_timeout < lite::kModelExecStreamSyncTimeoutUnlimitedValue ||
+      stream_sync_timeout == lite::kModelExecStreamSyncTimeoutIgnoreValue) {
+    MS_LOG(ERROR) << "stream_sync_timeout should be -1 or a positive integer, but got " << stream_sync_timeout;
+    return kLiteInputParamInvalid;
+  }
+  acl_options_ptr->model_exec_config.stream_sync_timeout = ModelExecConfigAttr<int32_t>(stream_sync_timeout);
+  return kSuccess;
+}
+
 std::shared_ptr<AclModelOptions> AclGraphExecutor::GenAclOptions() {
   auto acl_options_ptr = std::make_shared<AclModelOptions>();
   MS_CHECK_TRUE_MSG(acl_options_ptr != nullptr, nullptr, "Acl options make shared failed.");
@@ -157,7 +177,8 @@ std::shared_ptr<AclModelOptions> AclGraphExecutor::GenAclOptions() {
   if (output_name_str != "") {
     acl_options_ptr->output_names = lite::StrSplit(output_name_str, ",");
   }
-
+  auto ret = GetExecConfig(acl_options_ptr);
+  MS_CHECK_TRUE_MSG(ret == kSuccess, nullptr, "Get exec config failed.");
   return acl_options_ptr;
 }
 
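Note (illustration, not part of the diff): GetExecConfig above reads the new key from the "[ascend_context]" section of the model config, so the config-file equivalent of the config_dict used by the Python test further below would look like the following sketch, where the surrounding file name is up to the caller:

    [ascend_context]
    timeout=3

A value of -1 removes the limit; 0 and anything below -1 are rejected at build time by the range check above.
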
"Acl options make shared failed."); @@ -157,7 +177,8 @@ std::shared_ptr AclGraphExecutor::GenAclOptions() { if (output_name_str != "") { acl_options_ptr->output_names = lite::StrSplit(output_name_str, ","); } - + auto ret = GetExecConfig(acl_options_ptr); + MS_CHECK_TRUE_MSG(ret == kSuccess, nullptr, "Get exec config failed."); return acl_options_ptr; } diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h index c0abe87e..25bedfbc 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_graph_executor.h @@ -62,6 +62,7 @@ class AclGraphExecutor : public LiteGraphExecutor { private: Status BuildCustomAscendKernel(const CNodePtr &cnode); void GetShareMemInfos(std::shared_ptr acl_options_ptr); + Status GetExecConfig(const std::shared_ptr &acl_options_ptr); std::shared_ptr GenAclOptions(); bool GetDeviceID(int32_t *device_id); Status GetOutputTensors(const std::vector &output_names, std::vector *output_tensors); diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h index e7eee260..8aad509e 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/acl_model_options.h @@ -26,6 +26,23 @@ #include "acl/acl_mdl.h" namespace mindspore { + +template +class ModelExecConfigAttr { + using type = T; + T value_; + + public: + ModelExecConfigAttr() = default; + explicit ModelExecConfigAttr(const T &value) : value_(value) {} + T &Value() { return value_; } + static constexpr size_t Size() { return sizeof(T); } +}; + +struct ModelExecConfig { + ModelExecConfigAttr stream_sync_timeout{0}; +}; + struct AclModelOptions { int32_t device_id; std::string dump_path; @@ -41,6 +58,7 @@ struct AclModelOptions { std::vector output_names; std::string pids = ""; uint64_t sharable_handle = 0; + ModelExecConfig model_exec_config; AclModelOptions() : device_id(0) {} }; diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc index a8fad1b9..0a634661 100644 --- a/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc +++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc @@ -20,13 +20,11 @@ #include #include #include -#include -#include -#include #include #include #include #include +#include "common/common.h" #include "common/log_adapter.h" #include "src/common/utils.h" #include "src/common/log_util.h" @@ -111,6 +109,11 @@ static std::string ShapeToString(const std::vector &shape) { return result; } +bool CheckModelExecuteV2Support() { + return HAS_ASCEND_API(aclmdlExecuteV2) && HAS_ASCEND_API(aclmdlCreateExecConfigHandle) && + HAS_ASCEND_API(aclmdlDestroyExecConfigHandle) && HAS_ASCEND_API(aclmdlSetExecConfigOpt); +} + ModelProcess::~ModelProcess() { if (dynamic_dims_ != nullptr) { delete[] dynamic_dims_; @@ -168,6 +171,25 @@ bool ModelProcess::PreInitModelResource() { MS_LOG(ERROR) << "Create output buffer failed."; return false; } + auto &stream_sync_timeout = options_->model_exec_config.stream_sync_timeout; + if (stream_sync_timeout.Value() != lite::kModelExecStreamSyncTimeoutIgnoreValue) { + if (CheckModelExecuteV2Support()) { + acl_ret = CALL_ASCEND_API(aclrtCreateStream, &stream_); + if (acl_ret != ACL_SUCCESS) { + MS_LOG(ERROR) << "Create stream failed."; + 
diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc
index a8fad1b9..0a634661 100644
--- a/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc
+++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.cc
@@ -20,13 +20,11 @@
 #include
 #include
 #include
-#include
-#include
-#include
 #include
 #include
 #include
 #include
+#include "common/common.h"
 #include "common/log_adapter.h"
 #include "src/common/utils.h"
 #include "src/common/log_util.h"
@@ -111,6 +109,11 @@ static std::string ShapeToString(const std::vector<int64_t> &shape) {
   return result;
 }
 
+bool CheckModelExecuteV2Support() {
+  return HAS_ASCEND_API(aclmdlExecuteV2) && HAS_ASCEND_API(aclmdlCreateExecConfigHandle) &&
+         HAS_ASCEND_API(aclmdlDestroyExecConfigHandle) && HAS_ASCEND_API(aclmdlSetExecConfigOpt);
+}
+
 ModelProcess::~ModelProcess() {
   if (dynamic_dims_ != nullptr) {
     delete[] dynamic_dims_;
@@ -168,6 +171,25 @@ bool ModelProcess::PreInitModelResource() {
     MS_LOG(ERROR) << "Create output buffer failed.";
     return false;
   }
+  auto &stream_sync_timeout = options_->model_exec_config.stream_sync_timeout;
+  if (stream_sync_timeout.Value() != lite::kModelExecStreamSyncTimeoutIgnoreValue) {
+    if (CheckModelExecuteV2Support()) {
+      acl_ret = CALL_ASCEND_API(aclrtCreateStream, &stream_);
+      if (acl_ret != ACL_SUCCESS) {
+        MS_LOG(ERROR) << "Create stream failed.";
+        return false;
+      }
+      exec_config_handle_ = CALL_ASCEND_API(aclmdlCreateExecConfigHandle);
+      acl_ret = CALL_ASCEND_API(aclmdlSetExecConfigOpt, exec_config_handle_, ACL_MDL_STREAM_SYNC_TIMEOUT,
+                                &stream_sync_timeout.Value(), stream_sync_timeout.Size());
+      if (acl_ret != ACL_SUCCESS) {
+        MS_LOG(ERROR) << "Set stream sync timeout failed.";
+        return false;
+      }
+    } else {
+      MS_LOG(WARNING) << "The current CANN version does not support specifying stream_sync_timeout, please upgrade CANN.";
+    }
+  }
   if (is_dynamic_input_) {
     data_input_num_ = input_infos_.size();
     return true;
@@ -929,6 +951,22 @@ bool ModelProcess::UnLoad() {
     CALL_ASCEND_API(aclrtFreePhysical, shareable_phy_addr_);
     shareable_phy_addr_ = nullptr;
   }
+  if (stream_ != nullptr) {
+    ret = CALL_ASCEND_API(aclrtDestroyStream, stream_);
+    if (ret != ACL_SUCCESS) {
+      MS_LOG(ERROR) << "Destroy stream failed";
+      return false;
+    }
+    stream_ = nullptr;
+  }
+  if (exec_config_handle_ != nullptr) {
+    ret = CALL_ASCEND_API(aclmdlDestroyExecConfigHandle, exec_config_handle_);
+    if (ret != ACL_SUCCESS) {
+      MS_LOG(ERROR) << "Destroy exec config handle failed";
+      return false;
+    }
+    exec_config_handle_ = nullptr;
+  }
   MS_LOG(INFO) << "End unload model " << model_id_;
   return true;
 }
@@ -1383,6 +1421,20 @@ bool ModelProcess::ResetDynamicOutputTensor(const std::vector<MSTensor> *outputs
   return true;
 }
 
+Status ModelProcess::ExecuteModel(uint32_t model_id, aclmdlDataset *input, aclmdlDataset *output) {
+  aclError ret = ACL_SUCCESS;
+  if (stream_ && exec_config_handle_) {
+    ret = CALL_ASCEND_API(aclmdlExecuteV2, model_id, input, output, stream_, exec_config_handle_);
+  } else {
+    ret = CALL_ASCEND_API(aclmdlExecute, model_id, input, output);
+  }
+  if (ret != ACL_SUCCESS) {
+    MS_LOG(ERROR) << "Execute Model Failed, ret = " << ret << ", detail:" << CALL_ASCEND_API(aclGetRecentErrMsg);
+    return kLiteError;
+  }
+  return kSuccess;
+}
+
 bool ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> *outputs) {
   if (!loaded_) {
     MS_LOG(ERROR) << "Model has not been loaded";
@@ -1397,7 +1449,6 @@ bool ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, const st
     return false;
   }
 
-  aclError acl_ret;
   struct timeval start_time;
   auto env = std::getenv("GLOG_v");
   bool output_timecost = (env != nullptr && (env[0] == kINFOLogLevel || env[0] == kDEBUGLogLevel));
@@ -1409,7 +1460,7 @@ bool ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, const st
     MS_LOG(DEBUG) << "Need to lock before aclmdlExecute.";
     AclMemManager::GetInstance().Lock(options_->device_id);
   }
-  acl_ret = CALL_ASCEND_API(aclmdlExecute, infer_id_, inputs_, outputs_);
+  auto model_ret = ExecuteModel(infer_id_, inputs_, outputs_);
   if (is_sharing_workspace_) {
     MS_LOG(DEBUG) << "Unlock after aclmdlExecute.";
     AclMemManager::GetInstance().Unlock(options_->device_id);
@@ -1424,8 +1475,8 @@ bool ModelProcess::PredictFromHost(const std::vector<MSTensor> &inputs, const st
     MS_LOG(INFO) << "Model execute in " << cost << " us";
   }
 
-  if (acl_ret != ACL_SUCCESS) {
-    MS_LOG(ERROR) << "Execute Model Failed, ret = " << acl_ret << ", detail:" << CALL_ASCEND_API(aclGetRecentErrMsg);
+  if (model_ret != kSuccess) {
+    MS_LOG(ERROR) << "Execute Model Failed";
     return false;
   }
   if (is_dynamic_output_) {
diff --git a/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h
index a55655e5..d51999d8 100644
--- a/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h
+++ b/mindspore-lite/src/extendrt/delegate/ascend_acl/model_process.h
@@ -118,6 +118,7 @@ class ModelProcess {
   bool ShareMemProcess(const void *om_data, size_t om_data_size);
   MSTensor GetOutputWithZeroCopy(const std::vector<MSTensor> *outputs, size_t index);
   MSTensor CreateOutputTensor(size_t index);
+  Status ExecuteModel(uint32_t model_id, aclmdlDataset *inputs, aclmdlDataset *outputs);
 
   std::shared_ptr<AclModelOptions> options_;
   uint32_t model_id_ = UINT32_MAX;
@@ -129,6 +130,8 @@ class ModelProcess {
   aclmdlDataset *weight_inputs_ = nullptr;
   aclmdlDataset *weight_outputs_ = nullptr;
   aclmdlDesc *model_weight_desc_ = nullptr;
+  aclrtStream stream_ = nullptr;
+  aclmdlExecConfigHandle *exec_config_handle_ = nullptr;
 
   bool loaded_ = false;
   bool inited_weights_ = false;
diff --git a/mindspore-lite/test/st/python/python_api/conftest.py b/mindspore-lite/test/st/python/python_api/conftest.py
new file mode 100644
index 00000000..4a942d21
--- /dev/null
+++ b/mindspore-lite/test/st/python/python_api/conftest.py
@@ -0,0 +1,100 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Pytest ST configuration
+"""
+import pytest
+
+
+def pytest_addoption(parser):
+    """
+    pytest extra options
+    """
+    parser.addoption(
+        "--backend",
+        action="store",
+        nargs="+",
+        default=[],
+        choices=("arm_ascend310_cloud", "arm_ascend310_ge_cloud", "mslite_large_model_inference_arm_ascend910B"),
+        help="Only test specified backend testcases. Example: --backend arm_ascend310_cloud arm_ascend310_ge_cloud",
+    )
+
+    parser.addoption(
+        "--device_id",
+        action="store",
+        nargs="+",
+        default=[0],
+        type=int,
+        help="Available device ids for test, default is [0]. Example: --device_id 0 1",
+    )
+
+
+@pytest.fixture
+def device_id(request):
+    """
+    device_id fixture
+    """
+    return list(set(request.config.getoption("device_id")))
+
+
+def _parse_backend_mark(item, device_id_option):
+    """
+    parse backend mark from item
+    """
+    marker = item.get_closest_marker("backend")
+    if not marker:
+        return None
+    supported_backends = marker.args
+    require_device_num = marker.kwargs.get("require_device_num", 1)
+    if not supported_backends:
+        raise ValueError(f"item {item} marked with backend but no backend specified.")
+    if not isinstance(require_device_num, int):
+        raise TypeError(
+            f"item {item} marked with require_device_num={require_device_num},"
+            f" but require_device_num should be int, got {type(require_device_num)}."
+        )
+    if require_device_num < 0:
+        raise ValueError(
+            f"item {item} marked with require_device_num={require_device_num},"
+            f" but require_device_num should be non-negative."
+        )
+    if len(device_id_option) < require_device_num:
+        raise ValueError(
+            f"item {item} marked with require_device_num={require_device_num},"
+            f" but got device_id_option={device_id_option}, not enough devices."
+        )
+    return supported_backends
+
+
+def pytest_collection_modifyitems(config, items):
+    """
+    pytest collection modifyitems hook.
+    1. if a backend option is provided, keep only items whose backend mark matches a specified backend
+    """
+    backend_types = config.getoption("backend")
+    device_id_option = list(set(config.getoption("device_id")))
+    if not backend_types:
+        return
+
+    selected = []
+    for item in items:
+        supported_backends = _parse_backend_mark(item, device_id_option)
+        if not supported_backends:
+            continue
+        if "all" in supported_backends or any(backend in supported_backends for backend in backend_types):
+            selected.append(item)
+
+    config.hook.pytest_deselected(items=list(filter(lambda i: i not in selected, items)))
+    items[:] = selected
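Note (illustration, not part of the diff): with the hook and the marker registered in pytest.ini below, a run restricted to one backend and two devices would be invoked, for example, as:

    pytest test_stream_sync_timeout.py -c pytest.ini --backend mslite_large_model_inference_arm_ascend910B --device_id 0 1

Items whose backend mark does not match are deselected, and an item whose require_device_num exceeds the number of ids passed via --device_id fails during collection.
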
diff --git a/mindspore-lite/test/st/python/python_api/pytest.ini b/mindspore-lite/test/st/python/python_api/pytest.ini
new file mode 100644
index 00000000..3ee9fe3a
--- /dev/null
+++ b/mindspore-lite/test/st/python/python_api/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+markers =
+    backend(target, ..., *, require_device_num=1): testcase can run on target backend. Example: @pytest.mark.backend("arm_ascend310_cloud", "arm_ascend310_ge_cloud")
diff --git a/mindspore-lite/test/st/python/python_api/test_stream_sync_timeout.py b/mindspore-lite/test/st/python/python_api/test_stream_sync_timeout.py
new file mode 100644
index 00000000..30548e6e
--- /dev/null
+++ b/mindspore-lite/test/st/python/python_api/test_stream_sync_timeout.py
@@ -0,0 +1,120 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Test lite python API.
+"""
+import os
+from collections import namedtuple
+import pytest
+import mindspore_lite as mslite
+import numpy as np
+from utils import ScopeTimeRecord, expect_error
+
+
+MODEL_BASE_PATH = "."
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_LIMITED = {
+    "ascend_context": {"timeout": "3"},
+}
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_UNLIMITED = {
+    "ascend_context": {"timeout": "-1"},
+}
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_1 = {
+    "ascend_context": {"timeout": "0"},
+}
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_2 = {
+    "ascend_context": {"timeout": "-2"},
+}
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_3 = {
+    "ascend_context": {"timeout": "-2147483648"},
+}
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_4 = {
+    "ascend_context": {"timeout": "-2147483649"},
+}
+
+STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_5 = {
+    "ascend_context": {"timeout": "2147483648"},
+}
+
+ConfigAndWillError = namedtuple("ConfigAndWillError", ["config", "build_error", "infer_error"])
+
+
+def _create_context(device_id):
+    context = mslite.Context()
+    context.target = ["ascend"]
+    context.ascend.device_id = device_id
+    return context
+
+
+def _run_case_with_config(model_path, device_id, model_input, model_output, config_dict, build_error, infer_error):
+    """
+    run case with a config
+    """
+    print(f"run case with config: {config_dict}, expect: {build_error} {infer_error}", flush=True)
+    context = _create_context(device_id)
+    model = mslite.Model()
+
+    with expect_error(build_error):
+        model.build_from_file(model_path, mslite.ModelType.MINDIR, context, config_dict=config_dict)
+
+    with expect_error(infer_error):
+        with ScopeTimeRecord() as record:
+            model.predict(model_input, model_output)
+
+    print(f"model predict time with config {config_dict}: {record.duration} ms", flush=True)
+
+
+@pytest.mark.parametrize(
+    "model_name, inputs, outputs",
+    (
+        (
+            "sd1.5_unet.onnx_graph.mindir",
+            (
+                np.ones((2, 4, 64, 64)).astype(np.float32),
+                np.ones((1,)).astype(np.float32),
+                np.ones((2, 77, 768)).astype(np.float32),
+            ),
+            (np.ones((2, 4, 64, 64)).astype(np.float32),),
+        ),
+    ),
+)
+@pytest.mark.backend("mslite_large_model_inference_arm_ascend910B")
+def test_stream_sync_timeout(model_name, inputs, outputs, device_id):
+    """
+    test config stream_sync_timeout
+    """
+    model_path = os.path.join(MODEL_BASE_PATH, model_name)
+
+    model_input = [mslite.Tensor(tensor=i, device=f"ascend:{device_id[0]}") for i in inputs]
+    model_output = [mslite.Tensor(tensor=o, device=f"ascend:{device_id[0]}") for o in outputs]
+
+    # None * 3 for warm up
+    case_list = [ConfigAndWillError(None, None, None)] * 3 + [
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_UNLIMITED, None, None),
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_LIMITED, None, RuntimeError),
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_1, RuntimeError, RuntimeError),
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_2, RuntimeError, RuntimeError),
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_3, RuntimeError, RuntimeError),
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_4, RuntimeError, RuntimeError),
+        ConfigAndWillError(STREAM_SYNC_TIMEOUT_CONFIG_DICT_INVALID_5, RuntimeError, RuntimeError),
+    ]
+
+    for case in case_list:
+        _run_case_with_config(model_path, device_id[0], model_input, model_output, *case)
diff --git a/mindspore-lite/test/st/python/python_api/utils.py b/mindspore-lite/test/st/python/python_api/utils.py
new file mode 100644
index 00000000..d4ae7739
--- /dev/null
+++ b/mindspore-lite/test/st/python/python_api/utils.py
@@ -0,0 +1,42 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+utils for test
+"""
+import time
+import contextlib
+import pytest
+
+
+class ScopeTimeRecord:
+    """
+    time record
+    """
+
+    def __enter__(self):
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, *args):
+        self.finish_time = time.perf_counter()
+        self.duration = (self.finish_time - self.start_time) * 1000
+
+@contextlib.contextmanager
+def expect_error(errors, *args, **kwargs):
+    if errors:
+        with pytest.raises(errors, *args, **kwargs) as exc_info:
+            yield exc_info
+    else:
+        yield
diff --git a/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh b/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh
index e0e199d9..b9903ab3 100644
--- a/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh
+++ b/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh
@@ -400,6 +400,7 @@ if [[ "${MSLITE_ENABLE_COVERAGE}" == "on" || "${MSLITE_ENABLE_COVERAGE}" == "ON"
     python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_update_weight.py || exit 1
     python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_acl_profiling.py || exit 1
     python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_encrypt_and_decrypt.py || exit 1
+    python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_stream_sync_timeout.py -c pytest.ini --device_id ${device_id} || exit 1
 else
     pytest test_tensor.py || exit 1
     pytest test_model.py || exit 1
@@ -408,6 +409,7 @@ else
     pytest test_update_weight.py || exit 1
     pytest test_acl_profiling.py || exit 1
     pytest test_encrypt_and_decrypt.py || exit 1
+    pytest test_stream_sync_timeout.py -c pytest.ini --device_id ${device_id} || exit 1
 fi
 echo "---------- Run MindSpore Lite API SUCCESS ----------"
 #---------------------------------------------------------
--
Gitee
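
Note (illustration, not part of the diff): end to end, a C++ caller would exercise the new option through a config file passed to Model::LoadConfig before Build. A hedged sketch assuming the standard MindSpore Lite cloud-inference C++ API; the config and model file names are illustrative:

    #include <memory>
    #include "include/api/context.h"
    #include "include/api/model.h"

    int main() {
      // stream_sync_timeout.cfg contains:
      //   [ascend_context]
      //   timeout=3
      auto context = std::make_shared<mindspore::Context>();
      auto ascend = std::make_shared<mindspore::AscendDeviceInfo>();
      ascend->SetDeviceID(0);
      context->MutableDeviceInfo().push_back(ascend);

      mindspore::Model model;
      // LoadConfig must precede Build so that GetExecConfig sees the option.
      if (model.LoadConfig("stream_sync_timeout.cfg") != mindspore::kSuccess) {
        return 1;
      }
      if (model.Build("sd1.5_unet.onnx_graph.mindir", mindspore::ModelType::kMindIR,
                      context) != mindspore::kSuccess) {
        return 1;
      }
      // Predict now routes through aclmdlExecuteV2 with the configured timeout
      // when the CANN runtime supports it, and falls back to aclmdlExecute
      // with a warning otherwise.
      return 0;
    }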