From 3b08715f644cc2a82851e8cc71058ce021b489dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=86=8A=E6=94=80?=
Date: Tue, 14 Oct 2025 17:00:46 +0800
Subject: [PATCH] Build static update graph
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../ascend/single_matmul_model.onnx.config | 3 +
 ...le_matmul_model.onnx.variable_weights_file | 1 +
 .../config_level0/models_cloud_ascend_a2.cfg | 3 +-
 .../python/python_api/test_update_weight.py | 96 +++
 .../st/scripts/ascend/run_cloud_arm_a2.sh | 4 +
 .../cxx_api/model/acl/acl_model_options.h | 3 -
 .../cxx_api/model/acl/model_converter.cc | 25 +-
 .../cxx_api/model/acl/model_converter.h | 9 +
 .../adapter/acl/src/acl_pass_impl.cc | 3 +-
 .../tools/converter/cxx_api/converter_para.h | 4 +-
 .../tools/optimizer/common/gllo_utils.cc | 56 ++
 .../tools/optimizer/common/gllo_utils.h | 2 +
 .../optimizer/fusion/graph_split_pass.cc | 63 +-
 .../optimizer/graph/add_variable_node_pass.cc | 624 ++++--------------
 .../optimizer/graph/add_variable_node_pass.h | 37 +-
 15 files changed, 318 insertions(+), 615 deletions(-)
 create mode 100644 mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.config
 create mode 100644 mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.variable_weights_file
 create mode 100644 mindspore-lite/test/st/python/python_api/test_update_weight.py

diff --git a/mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.config b/mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.config
new file mode 100644
index 00000000..50a88b34
--- /dev/null
+++ b/mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.config
@@ -0,0 +1,3 @@
+# root path is mindspore-lite/test/st/scripts/ascend, so use this relative path to find the config
+[ascend_context]
+variable_weights_file=../../../config_level0/ascend/single_matmul_model.onnx.variable_weights_file
\ No newline at end of file
diff --git a/mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.variable_weights_file b/mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.variable_weights_file
new file mode 100644
index 00000000..081fbc15
--- /dev/null
+++ b/mindspore-lite/test/config_level0/ascend/single_matmul_model.onnx.variable_weights_file
@@ -0,0 +1 @@
+input_matrix:4,4;node_matmul
\ No newline at end of file
diff --git a/mindspore-lite/test/config_level0/models_cloud_ascend_a2.cfg b/mindspore-lite/test/config_level0/models_cloud_ascend_a2.cfg
index 1172048c..75488ee1 100644
--- a/mindspore-lite/test/config_level0/models_cloud_ascend_a2.cfg
+++ b/mindspore-lite/test/config_level0/models_cloud_ascend_a2.cfg
@@ -13,4 +13,5 @@ single_op_fa.onnx;3:q,k,v;1,32,1024,64:1,32,1024,64:1,32,1024,64;static;; 1
 single_op_gns.onnx;3:x,gamma,beta;1,192,160,160:192:192;static;; 1
 long_sequence_eta.pb;2:id,wt;1,576:1,576;static;; 1
 vod_sr_M5_H10_20231107_manualv4_int8_dynamic.onnx;1:input1;1,3,1080,1920;static;; 2
-user_latent_vector.pb;3:embedding-feature,dense,emb;1,10688:1,1:1,128;static;; 1
\ No newline at end of file
+user_latent_vector.pb;3:embedding-feature,dense,emb;1,10688:1,1:1,128;static;; 1
+single_matmul_model.onnx;1:input;1,4;static;; 1
\ No newline at end of file
diff --git a/mindspore-lite/test/st/python/python_api/test_update_weight.py b/mindspore-lite/test/st/python/python_api/test_update_weight.py
new file mode 100644
index 00000000..19cc0520
--- /dev/null
+++ b/mindspore-lite/test/st/python/python_api/test_update_weight.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Test for MindSpore Lite update_weights
+"""
+
+import pytest
+import mindspore_lite as mslite
+import numpy as np
+
+MODEL_FILE = "./single_matmul_model.onnx.mindir"
+DEVICE_ID = 0
+
+def test_update_weight_result_change():
+    model = mslite.Model()
+    context = mslite.Context()
+    context.target = ["ascend"]
+    context.ascend.device_id = DEVICE_ID
+    model.build_from_file(model_path=MODEL_FILE, model_type=mslite.ModelType.MINDIR, context=context)
+    np_input = np.ones((1, 4), dtype=np.float32)
+    ms_inputs = model.get_inputs()
+    ms_inputs[0].set_data_from_numpy(np_input)
+    outputs_nolora = model.predict(ms_inputs)[0].get_data_to_numpy()
+    weight = np.ones((4, 4), dtype=np.float32)
+    tensor = mslite.Tensor(weight)
+    model.update_weights([[tensor]])
+    outputs_lora = model.predict(ms_inputs)[0].get_data_to_numpy()
+    assert not np.allclose(outputs_nolora, outputs_lora)
+
+def test_update_weight_multiple_times():
+    try:
+        model = mslite.Model()
+        context = mslite.Context()
+        context.target = ["ascend"]
+        context.ascend.device_id = DEVICE_ID
+        model.build_from_file(model_path=MODEL_FILE, model_type=mslite.ModelType.MINDIR, context=context)
+        weight = np.ones((4, 4), dtype=np.float32)
+        tensor = mslite.Tensor(weight)
+        for i in range(5):
+            model.update_weights([[tensor]])
+    except Exception as e:
+        raise RuntimeError("test update weight multiple times failed!") from e
+
+def test_update_weight_zero_copy():
+    model = mslite.Model()
+    context = mslite.Context()
+    context.target = ["ascend"]
+    context.ascend.device_id = DEVICE_ID
+    model.build_from_file(model_path=MODEL_FILE, model_type=mslite.ModelType.MINDIR, context=context)
+    np_input = np.ones((1, 4), dtype=np.float32)
+    ms_inputs = model.get_inputs()
+    ms_inputs[0].set_data_from_numpy(np_input)
+    weight = np.ones((4, 4), dtype=np.float32)
+    tensor = mslite.Tensor(tensor=weight, device="ascend:"+str(DEVICE_ID))
+    model.update_weights([[tensor]])
+    outputs_lora = model.predict(ms_inputs)[0].get_data_to_numpy()
+    lora_out = np.ones((1, 4), dtype=np.float32) @ np.ones((4, 4), dtype=np.float32)
+    assert np.mean(np.abs(lora_out - outputs_lora)) < 1e-5
+
+def test_update_weight_precision():
+    model = mslite.Model()
+    context = mslite.Context()
+    context.target = ["ascend"]
+    context.ascend.device_id = DEVICE_ID
+    model.build_from_file(model_path=MODEL_FILE, model_type=mslite.ModelType.MINDIR, context=context)
+    np_input = np.ones((1, 4), dtype=np.float32)
+    ms_inputs = model.get_inputs()
+    ms_inputs[0].set_data_from_numpy(np_input)
+    weight = np.ones((4, 4), dtype=np.float32)
+    tensor = mslite.Tensor(weight)
+    model.update_weights([[tensor]])
+    outputs_lora = model.predict(ms_inputs)[0].get_data_to_numpy()
+    lora_out = np.ones((1, 4), dtype=np.float32) @ np.ones((4, 4), dtype=np.float32)
+    assert np.mean(np.abs(lora_out - outputs_lora)) < 1e-5
+
+def test_update_weight_empty_weight():
+    model = mslite.Model()
+    context = mslite.Context()
+    context.target = ["ascend"]
+    context.ascend.device_id = DEVICE_ID
+    model.build_from_file(model_path=MODEL_FILE, model_type=mslite.ModelType.MINDIR, context=context)
+    with pytest.raises(RuntimeError) as e:
+        model.update_weights([[]])
+    assert "update weight failed! Error is Common error code" in str(e.value)
diff --git a/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh b/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh
index 76bb5170..81975e67 100644
--- a/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh
+++ b/mindspore-lite/test/st/scripts/ascend/run_cloud_arm_a2.sh
@@ -387,6 +387,7 @@ fi
 echo "---------- Run MindSpore Lite API ----------"
 cd ${basepath}/python/python_api/ || exit 1
 cp -r ${ms_models_path}/sd1.5_unet.onnx* . || exit 1 # for Model Predict ST
+cp -r ${ms_models_path}/single_matmul_model.onnx.mindir . 
|| exit 1 # for Update weights ST #for code coverage in A2 if [[ "${MSLITE_ENABLE_COVERAGE}" == "on" || "${MSLITE_ENABLE_COVERAGE}" == "ON" ]]; then echo "MSLITE_ENABLE_COVERAGE: ${MSLITE_ENABLE_COVERAGE}, MSLITE_COVERAGE_FILE: ${MSLITE_COVERAGE_FILE}" @@ -394,13 +395,16 @@ if [[ "${MSLITE_ENABLE_COVERAGE}" == "on" || "${MSLITE_ENABLE_COVERAGE}" == "ON" python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_model.py || exit 1 python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_model_parallel_runner.py || exit 1 python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_model_info.py || exit 1 + python3 -m coverage run --rcfile=${MSLITE_COVERAGE_FILE} -m pytest test_update_weight.py || exit 1 else pytest test_tensor.py || exit 1 pytest test_model.py || exit 1 pytest test_model_parallel_runner.py || exit 1 pytest test_model_info.py || exit 1 + pytest test_update_weight.py || exit 1 fi echo "---------- Run MindSpore Lite API SUCCESS ----------" #--------------------------------------------------------- + echo "success" exit 0 diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_options.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_options.h index 3730f4df..a3b67cdb 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_options.h +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/acl_model_options.h @@ -57,8 +57,6 @@ class MS_API AclModelOptions { } std::map GetAoeGlobalOptionsMap() const { return aoe_global_options_map_; } static std::string GetSocName(); - std::vector GetConstName() const { return const_names_; } - void SetConstName(const std::vector &const_names) { const_names_ = const_names; } bool IsLastModel() { return is_last_model_; } void SetLastModel() { is_last_model_ = true; } @@ -87,7 +85,6 @@ class MS_API AclModelOptions { std::string om_file_path_; std::string aoe_mode_; std::string dump_model_name_; - std::vector const_names_; bool is_last_model_ = false; }; } // namespace mindspore diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc index 235bea7a..ac8f5e6b 100644 --- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc +++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "backend/ge_backend/graph_ir/utils.h" #include "graph/graph_buffer.h" #include "graph/graph.h" @@ -28,6 +29,8 @@ #include "plugin/ascend/res_manager/symbol_interface/symbol_utils.h" #include "src/common/file_utils.h" #include "cxx_api/graph/acl/acl_convert_init_adapter.h" +#include "mindspore/ops/infer/custom.h" +#include "mindspore/core/include/ir/func_graph.h" namespace mindspore { namespace { @@ -113,16 +116,12 @@ Buffer ModelConverter::BuildAirModel(const backend::ge_backend::DfGraphPtr &grap #ifdef ENABLE_BUNDLE ge::WeightRefreshableGraphs split_graphs; - std::vector ascend_const_names; - std::vector const_names; - if (option != nullptr && !option->GetConstName().empty()) { - const_names = option->GetConstName(); - } - if (const_names.size() > 0) { - ascend_const_names.resize(const_names.size()); - std::transform(const_names.begin(), const_names.end(), ascend_const_names.begin(), + std::vector 
<ge::AscendString> ascend_variable_names;
+  if (variable_node_names_.size() > 0 && update_func_graph_ != nullptr) {
+    ascend_variable_names.resize(variable_node_names_.size());
+    std::transform(variable_node_names_.begin(), variable_node_names_.end(), ascend_variable_names.begin(),
                    [](std::string s) { return ge::AscendString(s.c_str()); });
-    auto ret = ge::aclgrphConvertToWeightRefreshableGraphs(*graph, ascend_const_names, split_graphs);
+    auto ret = ge::aclgrphConvertToWeightRefreshableGraphs(*graph, ascend_variable_names, split_graphs);
     if (ret != 0) {
       MS_LOG(ERROR) << "aclgraphConvertToWeightRefreshableGraphs failed! ret:" << ret;
       ge::aclgrphBuildFinalize();
@@ -138,9 +137,15 @@ Buffer ModelConverter::BuildAirModel(const backend::ge_backend::DfGraphPtr &grap
       update_options.insert(std::make_pair(ge::AscendString(it.first.c_str()), ge::AscendString(it.second.c_str())));
     }
   }
+  auto update_graph = ConvertFuncGraphToAIR(update_func_graph_);
+  if (update_graph == nullptr) {
+    MS_LOG(ERROR) << "Convert FuncGraph to AscendIR failed.";
+    return Buffer();
+  }
+  std::vector<ge::GraphWithOptions> graph_and_options;
   graph_and_options.push_back(ge::GraphWithOptions{split_graphs.infer_graph, bund_bundle_options});
-  graph_and_options.push_back(ge::GraphWithOptions{split_graphs.var_update_graph, update_options});
+  graph_and_options.push_back(ge::GraphWithOptions{*update_graph, update_options});
   ret = ge::aclgrphBundleBuildModel(graph_and_options, model);
   if (ret != ge::SUCCESS) {
     MS_LOG(ERROR) << "Call aclgrphBuildModel fail: " << CALL_ASCEND_API(aclGetRecentErrMsg);
diff --git a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.h b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.h
index 91245774..af8f40a1 100644
--- a/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.h
+++ b/mindspore-lite/tools/converter/adapter/acl/cxx_api_lite/cxx_api/model/acl/model_converter.h
@@ -19,6 +19,7 @@
 #include <map>
 #include <memory>
 #include <string>
+#include <vector>
 #include "include/cxx_api/types.h"
 #include "include/cxx_api/status.h"
 #include "ir/func_graph.h"
@@ -35,6 +36,10 @@ class MS_API ModelConverter {
   Buffer LoadMindIR(const FuncGraphPtr &func_graph);
   void set_options(const std::weak_ptr<AclModelOptions> &options) { options_ = options; }
+  void set_update_graph(const FuncGraphPtr &update_graph) { update_func_graph_ = update_graph; }
+  void set_variable_node_names(const std::vector<std::string> &variable_node_names) {
+    variable_node_names_ = variable_node_names;
+  }
 private:
  backend::ge_backend::DfGraphPtr ConvertFuncGraphToAIR(const FuncGraphPtr &anf_graph) const;
@@ -43,8 +48,12 @@ class MS_API ModelConverter {
                         const std::map<std::string, std::string> &build_options) const;
   Buffer LoadAscendIRInner(const Buffer &model_data);
   Status SaveModel(const ge::ModelBufferData &model) const;
+  Status CreateUpdateGraph(const std::vector<std::string> &const_names, const std::vector<AbstractBasePtr> &abstracts,
+                           backend::ge_backend::DfGraphPtr *df_graph) const;
   std::weak_ptr<AclModelOptions> options_;
+  std::vector<std::string> variable_node_names_;
+  FuncGraphPtr update_func_graph_ = nullptr;
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_CONVERTER_H
diff --git a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc
index 1f5a5eb5..e5fc5c9d 100644
--- a/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc
+++ b/mindspore-lite/tools/converter/adapter/acl/src/acl_pass_impl.cc
@@ -994,7 +994,8 @@ STATUS AclPassImpl::ConvertGraphToOm(const FuncGraphPtr &func_graph, Buffer *om_
Buffer *om_ } // call interface of cloud ModelConverter model_converter; - options_->SetConstName(param_->const_names); + model_converter.set_update_graph(param_->update_graph); + model_converter.set_variable_node_names(param_->variable_node_names); model_converter.set_options(options_); *om_data = model_converter.LoadMindIR(func_graph); if (om_data->Data() == nullptr || om_data->DataSize() == 0) { diff --git a/mindspore-lite/tools/converter/cxx_api/converter_para.h b/mindspore-lite/tools/converter/cxx_api/converter_para.h index 915b1109..5ac25483 100644 --- a/mindspore-lite/tools/converter/cxx_api/converter_para.h +++ b/mindspore-lite/tools/converter/cxx_api/converter_para.h @@ -116,7 +116,9 @@ struct ConverterPara { SplitGraphCfg splitGraphCfg; // configs parse from config_file ConfigInfos config_infos; - std::vector const_names; + std::vector variable_node_names; + std::vector variable_node_abstracts; + FuncGraphPtr update_graph = nullptr; }; } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_CXX_API_CONVERTER_PARA_H_ diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.cc b/mindspore-lite/tools/optimizer/common/gllo_utils.cc index 5b433bba..0e0220f6 100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.cc @@ -59,6 +59,8 @@ #include "mindspore/core/include/ir/func_graph_flag.h" #include "ir/tensor_new.h" #include "mindspore/core/include/ir/graph_utils.h" +#include "mindspore/ops/infer/return.h" +#include "mindspore/ops/infer/make_tuple.h" namespace mindspore { namespace opt { @@ -2057,5 +2059,59 @@ STATUS GetPrimFromCnode(const CNodePtr &cnode, PrimitivePtr *prim_ptr) { return lite::RET_OK; } +Status BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector &return_inputs) { + MS_CHECK_TRUE_RET(anf_graph != nullptr, kLiteNullptr); + auto return_prim = std::make_shared(); + if (return_prim == nullptr) { + MS_LOG(ERROR) << "new return failed!"; + return kLiteNullptr; + } + if (return_inputs.empty()) { + MS_LOG(ERROR) << "return input is empty"; + return kLiteError; + } + auto final_return = return_inputs; + AbstractBasePtr abstract = nullptr; + if (return_inputs.size() == 1) { + anf_graph->set_output(return_inputs.front(), false); + abstract = return_inputs.front()->abstract(); + MS_CHECK_TRUE_MSG(abstract != nullptr, kLiteNullptr, "abstract is nullptr!"); + } else if (return_inputs.size() > 1) { + auto make_tuple_prim_ptr = std::make_shared(); + if (make_tuple_prim_ptr == nullptr) { + MS_LOG(DEBUG) << "new maketyple failed"; + return kLiteNullptr; + } + AbstractBasePtrList elem; + std::transform(return_inputs.begin(), return_inputs.end(), std::back_inserter(elem), + [](auto &node) { return node->abstract(); }); + auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim(); + MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, kLiteNullptr, "make_tuple_prim_c is nullptr!"); + auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim_c, return_inputs); + if (make_tuple_cnode == nullptr) { + MS_LOG(ERROR) << "new cnode failed!"; + return kLiteNullptr; + } + make_tuple_cnode->set_fullname_with_scope("return tuple"); + abstract = std::make_shared(elem); + MS_CHECK_TRUE_MSG(abstract != nullptr, kLiteNullptr, "abstract is nullptr!"); + make_tuple_cnode->set_abstract(abstract); + final_return = {make_tuple_cnode}; + } else { + MS_LOG(ERROR) << "Return inputs is 0!"; + return kLiteError; + } + auto return_prim_c = return_prim->GetPrim(); + MS_CHECK_TRUE_MSG(return_prim_c != nullptr, kLiteNullptr, 
"return_prim_c is nullptr!"); + auto return_cnode = anf_graph->NewCNode(return_prim_c, final_return); + if (return_cnode == nullptr) { + MS_LOG(ERROR) << "new cnode error"; + return kLiteError; + } + return_cnode->set_fullname_with_scope("Return"); + return_cnode->set_abstract(abstract); + anf_graph->set_return(return_cnode); + return kSuccess; +} }; // namespace opt } // namespace mindspore diff --git a/mindspore-lite/tools/optimizer/common/gllo_utils.h b/mindspore-lite/tools/optimizer/common/gllo_utils.h index c75381cd..41bce601 100644 --- a/mindspore-lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore-lite/tools/optimizer/common/gllo_utils.h @@ -32,6 +32,7 @@ #include "infer/cxx_api/conv2d_backprop_input_fusion.h" #include "schema/inner/model_generated.h" #include "tools/converter/converter_context.h" +#include "include/cxx_api/status.h" using PrimitiveCPtr = std::shared_ptr; using mindspore::lite::RET_ERROR; @@ -227,6 +228,7 @@ inline bool IsSpecifiedNode(const BaseRef &n) { return false; } +Status BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector &return_inputs); tensor::TensorPtr GetTensorFromParameterNode(const EquivPtr &equiv, const VarPtr &input); const float GetFloatParameterValue(const EquivPtr &equiv, const VarPtr &input); const int GetIntParameterValue(const EquivPtr &equiv, const VarPtr &input); diff --git a/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc b/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc index 6ac2cab1..e11e830a 100644 --- a/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc +++ b/mindspore-lite/tools/optimizer/fusion/graph_split_pass.cc @@ -21,73 +21,14 @@ #include #include #include -#include "infer/return.h" -#include "tools/converter/export_model.h" -#include "infer/make_tuple.h" #include "mindspore/core/include/ir/graph_utils.h" +#include "tools/optimizer/common/gllo_utils.h" namespace mindspore::opt { namespace { constexpr size_t kTargetNodeSize = 2; } -STATUS BuildReturnNode(const FuncGraphPtr &anf_graph, const std::vector &return_inputs) { - MS_CHECK_TRUE_RET(anf_graph != nullptr, lite::RET_NULL_PTR); - auto return_prim = std::make_shared(); - if (return_prim == nullptr) { - MS_LOG(ERROR) << "new return failed!"; - return lite::RET_NULL_PTR; - } - if (return_inputs.empty()) { - MS_LOG(ERROR) << "return input is empty"; - return lite::RET_ERROR; - } - auto final_return = return_inputs; - AbstractBasePtr abstract = nullptr; - if (return_inputs.size() == 1) { - anf_graph->set_output(return_inputs.front(), false); - abstract = return_inputs.front()->abstract(); - } else if (return_inputs.size() > 1) { - auto make_tuple_prim_ptr = std::make_shared(); - if (make_tuple_prim_ptr == nullptr) { - MS_LOG(DEBUG) << "new maketyple failed"; - return lite::RET_NULL_PTR; - } - AbstractBasePtrList elem; - std::transform(return_inputs.begin(), return_inputs.end(), std::back_inserter(elem), - [](auto &node) { return node->abstract(); }); - auto make_tuple_prim_c = make_tuple_prim_ptr->GetPrim(); - MS_CHECK_TRUE_MSG(make_tuple_prim_c != nullptr, lite::RET_NULL_PTR, "make_tuple_prim_c is nullptr!"); - auto make_tuple_cnode = anf_graph->NewCNode(make_tuple_prim_c, return_inputs); - if (make_tuple_cnode == nullptr) { - MS_LOG(ERROR) << "new cnode failed!"; - return lite::RET_NULL_PTR; - } - make_tuple_cnode->set_fullname_with_scope("return tuple"); - make_tuple_cnode->set_abstract(std::make_shared(elem)); - abstract = make_tuple_cnode->abstract(); - final_return = {make_tuple_cnode}; - } else { - MS_LOG(ERROR) << "Return inputs 
is 0!"; - return lite::RET_ERROR; - } - if (abstract == nullptr) { - MS_LOG(ERROR) << "Input node abstract is null, node:" << final_return.front()->fullname_with_scope(); - return lite::RET_ERROR; - } - auto return_prim_c = return_prim->GetPrim(); - CHECK_NULL_RETURN(return_prim_c); - auto return_cnode = anf_graph->NewCNode(return_prim_c, final_return); - if (return_cnode == nullptr) { - MS_LOG(ERROR) << "new cnode error"; - return lite::RET_ERROR; - } - return_cnode->set_fullname_with_scope("Return"); - return_cnode->set_abstract(abstract); - anf_graph->set_return(return_cnode); - return lite::RET_OK; -} - bool IsWeight(const AnfNodePtr &node) { return (utils::isa(node) && node->cast() != nullptr && node->cast()->has_default()); @@ -677,7 +618,7 @@ bool GraphSplitPass::Run(const FuncGraphPtr &original_graph) { for (auto subgraph_output : subgraph_output_vec[i]) { subgraph_output_names[i].push_back(subgraph_output->fullname_with_scope()); } - if (BuildReturnNode(subgraphs[i], subgraph_output_vec[i]) != lite::RET_OK) { + if (BuildReturnNode(subgraphs[i], subgraph_output_vec[i]) != kSuccess) { MS_LOG(ERROR) << "build return node failed!"; return false; } diff --git a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc index ada4fe63..d2205dce 100644 --- a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc +++ b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "src/common/log_adapter.h" #include "src/common/log_util.h" #include "src/common/common.h" @@ -36,6 +37,7 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_c.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_m.h" #include "mindspore/core/include/ir/graph_utils.h" +#include "mindspore/ops/infer/custom.h" namespace mindspore { namespace opt { @@ -55,6 +57,12 @@ constexpr float kInitOne = 1.0; constexpr size_t kInitBatchSize = 1; constexpr size_t kMaxConfigLen = 1e6; constexpr uint16_t kFloatOne = 15360; + +bool MatchPattern(const std::string &input) { + std::regex pattern(R"(^([^:;]+):(\d+(?:,\d+)*);([^:;]+)$)"); + return std::regex_match(input, pattern); +} + } // namespace template @@ -105,9 +113,9 @@ TypeId FetchTypeIdByNode(const AnfNodePtr &node) { return type_id; } -lite::STATUS FetchWeightShape(AnfNodePtr weight, ShapeVector *weight_shape, const CNodePtr &cnode, bool is_matmul) { +lite::STATUS FetchWeightShape(AnfNodePtr weight, ShapeVector *weight_shape) { if (!utils::isa(weight)) { - MS_LOG(ERROR) << "matmul weight is not constant, can not update weight!"; + MS_LOG(ERROR) << "weight is not ParameterPtr!"; return RET_ERROR; } auto weight_param = weight->cast(); @@ -120,340 +128,10 @@ lite::STATUS FetchWeightShape(AnfNodePtr weight, ShapeVector *weight_shape, cons return RET_OK; } -lite::STATUS CreateBMMNode(AnfNodePtrList &&bmm_inputs, const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const std::string &suffix, AnfNodePtr *bmm_param) { - auto bmm = std::make_shared(); - auto bmm_prim_c = bmm->GetPrim(); - auto bmm_cnode = func_graph->NewCNode(bmm_prim_c, bmm_inputs); - if (bmm_cnode == nullptr) { - MS_LOG(ERROR) << "new bmm node failed!"; - return RET_ERROR; - } - bmm_cnode->set_fullname_with_scope(node->fullname_with_scope() + suffix); - if (!utils::isa(bmm_cnode)) { - MS_LOG(ERROR) << "matmul weight is not constant, can not update weight!"; - return RET_OK; - } - *bmm_param = bmm_cnode->cast(); - if (node->abstract() != 
nullptr) { - bmm_cnode->set_abstract(node->abstract()->Clone()); - } - return RET_OK; -} - -lite::STATUS CreateMulNode(AnfNodePtrList &&mul_inputs, const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const std::string &suffix, AnfNodePtr *mul_param) { - auto mul = std::make_shared(); - auto mul_prim_c = mul->GetPrim(); - auto mul_cnode = func_graph->NewCNode(mul_prim_c, mul_inputs); - if (mul_cnode == nullptr) { - MS_LOG(ERROR) << "new alpha mul node failed!"; - return false; - } - mul_cnode->set_fullname_with_scope(node->fullname_with_scope() + suffix); - if (!utils::isa(mul_cnode)) { - MS_LOG(ERROR) << "matmul weight is not constant, can not update weight!"; - return RET_ERROR; - } - *mul_param = mul_cnode->cast(); - if (node->abstract() != nullptr) { - mul_cnode->set_abstract(node->abstract()->Clone()); - } - return RET_OK; -} - -lite::STATUS CreateReduceSumNode(AnfNodePtrList &&reduce_sum_inputs, const FuncGraphPtr &func_graph, - const AnfNodePtr &node, const std::string &suffix, AnfNodePtr *reduce_sum_param) { - auto reduce_sum = std::make_shared(); - auto reduce_sum_prim_c = reduce_sum->GetPrim(); - auto reduce_sum_cnode = func_graph->NewCNode(reduce_sum_prim_c, reduce_sum_inputs); - if (reduce_sum_cnode == nullptr) { - MS_LOG(ERROR) << "new reduce sum node failed!"; - return RET_ERROR; - } - reduce_sum_cnode->set_fullname_with_scope(node->fullname_with_scope() + suffix); - if (!utils::isa(reduce_sum_cnode)) { - MS_LOG(ERROR) << "matmul weight is not constant, can not update weight!"; - return RET_ERROR; - } - *reduce_sum_param = reduce_sum_cnode->cast(); - if (node->abstract() != nullptr) { - reduce_sum_cnode->set_abstract(node->abstract()->Clone()); - } - return RET_OK; -} - -lite::STATUS CreateTransposeNode(AnfNodePtrList &&transpose_inputs, const FuncGraphPtr &func_graph, - const AnfNodePtr &node, const std::string &suffix, AnfNodePtr *transpose_param) { - auto transpose = std::make_shared(); - auto transpose_prim_c = transpose->GetPrim(); - auto transpose_cnode = func_graph->NewCNode(transpose_prim_c, transpose_inputs); - if (transpose_cnode == nullptr) { - MS_LOG(ERROR) << "new reduce sum node failed!"; - return false; - } - transpose_cnode->set_fullname_with_scope(node->fullname_with_scope() + suffix); - if (!utils::isa(transpose_cnode)) { - MS_LOG(ERROR) << "matmul weight is not constant, can not update weight!"; - return RET_ERROR; - } - *transpose_param = transpose_cnode->cast(); - if (node->abstract() != nullptr) { - transpose_cnode->set_abstract(node->abstract()->Clone()); - } - return RET_OK; -} - -lite::STATUS CreateAddNode(AnfNodePtrList &&add_inputs, const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const std::string &suffix, CNodePtr *add_cnode) { - auto add = std::make_shared(); - auto add_prim_c = add->GetPrim(); - (*add_cnode) = func_graph->NewCNode(add_prim_c, add_inputs); - if (*add_cnode == nullptr) { - MS_LOG(ERROR) << "new add node failed!"; - return RET_ERROR; - } - (*add_cnode)->set_fullname_with_scope(node->fullname_with_scope() + suffix); - if (node->abstract() != nullptr) { - (*add_cnode)->set_abstract(node->abstract()->Clone()); - } - return RET_OK; -} - -lite::STATUS FetchNodeNameMap(const CNodePtr &cnode, std::unordered_map *node_name_map, - const bool &has_alpha) { - auto node_name = cnode->fullname_with_scope(); - size_t last_slash_pos = node_name.find_last_of('/'); - MS_CHECK_TRUE_RET(last_slash_pos != std::string::npos, RET_ERROR); - auto search_key = node_name.substr(0, last_slash_pos); - (*node_name_map)[search_key + "variable_up"] 
= cnode->fullname_with_scope() + "_lora_up_const"; - (*node_name_map)[search_key + "variable_down"] = cnode->fullname_with_scope() + "_lora_down_const"; - if (has_alpha) { - (*node_name_map)[search_key + "variable_alpha"] = cnode->fullname_with_scope() + "_lora_alpha_const"; - } - return RET_OK; -} - -lite::STATUS InsertVariableNodePass::InsertVariableNodeForMatmul( - const AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph, const std::vector &up_shape, - std::unordered_map *node_name_map, bool has_alpha, int max_weight_batch) { - MS_CHECK_TRUE_RET(cnode->inputs().size() >= kInputSize3 && up_shape.size() == kConstantMatmulWeightShapeSize, - RET_ERROR); - auto weight = cnode->input(kInputIndex2); - MS_CHECK_TRUE_RET(weight != nullptr, RET_ERROR); - ShapeVector weight_shape; - auto ret = FetchWeightShape(weight, &weight_shape, cnode, true); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "fetch wieght shape failed! ret:" << ret << "!"; - return ret; - } - auto low_rank = MIN(up_shape[kIndex0], up_shape[kIndex1]); - auto value_node = cnode->input(kIndex0)->cast(); - MS_CHECK_TRUE_RET(value_node != nullptr, RET_ERROR); - auto src_prim = GetValueNode(value_node); - MS_CHECK_TRUE_RET(src_prim != nullptr, RET_ERROR); - int64_t up_high_rank = weight_shape[kIndex1]; - int64_t down_high_rank = weight_shape[kIndex0]; - bool is_gemm = false; - if (src_prim->GetAttr(mindspore::ops::kTransposeA) != nullptr || - src_prim->GetAttr(mindspore::ops::kTransposeB) != nullptr) { - up_high_rank = weight_shape[kIndex0]; - down_high_rank = weight_shape[kIndex1]; - is_gemm = true; - } - ShapeVector lora_up_shape = {max_weight_batch, up_high_rank, low_rank}; - ShapeVector lora_down_shape = {max_weight_batch, low_rank, down_high_rank}; - ShapeVector lora_add_shape = {max_weight_batch, kInitBatchSize, kInitBatchSize}; - ShapeVector lora_alpha_shape = {max_weight_batch, kInitBatchSize, kInitBatchSize}; - AnfNodePtr lora_up_param_node = BuildZeroVecNDParameterNode( - func_graph, lora_up_shape, cnode->fullname_with_scope() + "_lora_up", 0.0, kNumberTypeFloat16); - AnfNodePtr lora_down_param_node = BuildZeroVecNDParameterNode( - func_graph, lora_down_shape, cnode->fullname_with_scope() + "_lora_down", 0.0, kNumberTypeFloat16); - AnfNodePtr add_weights_param_node = BuildZeroVecNDParameterNode( - func_graph, lora_add_shape, cnode->fullname_with_scope() + "_lora_add_weights", kInitOne, kNumberTypeFloat32); - AnfNodePtr alpha_param_node = BuildZeroVecNDParameterNode( - func_graph, lora_alpha_shape, cnode->fullname_with_scope() + "_lora_alpha", kFloatOne, kNumberTypeFloat16); - AnfNodePtr axes_param_node = - opt::BuildIntValueParameterNode(func_graph, kInitZero, cnode->fullname_with_scope() + "_reduce_sum_axes", true); - MS_CHECK_TRUE_RET(lora_up_param_node != nullptr && lora_down_param_node != nullptr && - add_weights_param_node != nullptr && alpha_param_node != nullptr && axes_param_node != nullptr, - RET_ERROR); - if (FetchNodeNameMap(cnode, node_name_map, has_alpha) != RET_OK) { - MS_LOG(ERROR) << "FetchNodeNameMap failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto bmm_inputs = {lora_up_param_node, lora_down_param_node}; - AnfNodePtr bmm_param = nullptr; - if (CreateBMMNode(bmm_inputs, func_graph, node, "_lora_bmm", &bmm_param) != RET_OK) { - MS_LOG(ERROR) << "Create BMM node failed! 
ret:" << ret << "!"; - return RET_ERROR; - } - auto mul_alpha_inputs = {bmm_param, alpha_param_node}; - AnfNodePtr alpha_mul_param = nullptr; - if (CreateMulNode(mul_alpha_inputs, func_graph, node, "_lora_alpha_mul", &alpha_mul_param) != RET_OK) { - MS_LOG(ERROR) << "Create mul node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto mul_add_weights_inputs = {alpha_mul_param, add_weights_param_node}; - AnfNodePtr mul_add_weights_param = nullptr; - if (CreateMulNode(mul_add_weights_inputs, func_graph, node, "_lora_add_weights_mul", &mul_add_weights_param) != - RET_OK) { - MS_LOG(ERROR) << "Create mul node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto reduce_sum_inputs = {mul_add_weights_param, axes_param_node}; - AnfNodePtr reduce_sum_param = nullptr; - if (CreateReduceSumNode(reduce_sum_inputs, func_graph, node, "_lora_reduce_sum", &reduce_sum_param) != RET_OK) { - MS_LOG(ERROR) << "Create reducesum node failed! ret:" << ret << "!"; - return RET_ERROR; - } - std::vector perm = {kIndex1, kIndex0}; - if (is_gemm) { - perm = {kIndex0, kIndex1}; - } - AnfNodePtr perm_param_node = - opt::BuildIntVecParameterNode(func_graph, perm, cnode->fullname_with_scope() + "_trans_perm"); - auto transpose_inputs = {reduce_sum_param, perm_param_node}; - AnfNodePtr transpose_param = nullptr; - if (CreateTransposeNode(transpose_inputs, func_graph, node, "_lora_transpose", &transpose_param) != RET_OK) { - MS_LOG(ERROR) << "Create transpose node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto add_inputs = {transpose_param, weight}; - CNodePtr add_cnode = nullptr; - if (CreateAddNode(add_inputs, func_graph, node, "_lora_add", &add_cnode) != RET_OK) { - MS_LOG(ERROR) << "Create Add node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto manager = Manage(func_graph); - (void)manager->Replace(weight, add_cnode); - return RET_OK; -} - -lite::STATUS InsertVariableNodePass::InsertVariableNodeForConv( - const AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph, const std::vector &up_shape, - std::unordered_map *node_name_map, bool has_alpha, int max_weight_batch) { - MS_CHECK_TRUE_RET(cnode->inputs().size() >= kInputSize3, RET_ERROR); - MS_CHECK_TRUE_RET(up_shape.size() == kConstantConvWeightShapeSize, RET_ERROR); - auto weight = cnode->input(kInputIndex2); - MS_CHECK_TRUE_RET(weight != nullptr, RET_ERROR); - ShapeVector weight_shape; - auto ret = FetchWeightShape(weight, &weight_shape, cnode, false); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "fetch wieght shape failed! 
ret:" << ret << "!"; - return ret; - } - int kernel_size_down = weight_shape[kIndex2] / up_shape[kIndex2]; - ShapeVector lora_up_shape = {max_weight_batch, up_shape[kIndex0], up_shape[kIndex1], up_shape[kIndex2], - up_shape[kIndex3]}; - ShapeVector lora_down_shape = {max_weight_batch, up_shape[kIndex1], weight_shape[kIndex1], kernel_size_down, - kernel_size_down}; - ShapeVector add_weights_shape = {max_weight_batch, kIndex1, kIndex1, kIndex1, kIndex1}; - ShapeVector alpha_weights_shape = {max_weight_batch, kIndex1, kIndex1, kIndex1, kIndex1}; - AnfNodePtr lora_up_param_node = BuildZeroVecNDParameterNode( - func_graph, lora_up_shape, cnode->fullname_with_scope() + "_lora_up", 0.0, kNumberTypeFloat16); - AnfNodePtr lora_down_param_node = BuildZeroVecNDParameterNode( - func_graph, lora_down_shape, cnode->fullname_with_scope() + "_lora_down", 0.0, kNumberTypeFloat16); - AnfNodePtr add_weights_param_node = BuildZeroVecNDParameterNode( - func_graph, add_weights_shape, cnode->fullname_with_scope() + "_lora_add", kInitOne, kNumberTypeFloat32); - AnfNodePtr alpha_param_node = BuildZeroVecNDParameterNode( - func_graph, alpha_weights_shape, cnode->fullname_with_scope() + "_lora_alpha", kFloatOne, kNumberTypeFloat16); - ret = FetchNodeNameMap(cnode, node_name_map, has_alpha); - if (ret != RET_OK) { - MS_LOG(ERROR) << "FetchNodeNameMap failed! ret:" << ret << "!"; - return ret; - } - AnfNodePtr axes_param_node = - opt::BuildIntValueParameterNode(func_graph, kInitZero, cnode->fullname_with_scope() + "_reduce_sum_axes", true); - std::vector perm = {kIndex3, kIndex2, kIndex1, kIndex0}; - AnfNodePtr perm_param_node = - opt::BuildIntVecParameterNode(func_graph, perm, cnode->fullname_with_scope() + "_trans_perm"); - std::vector perm_reverse = {kIndex0, kIndex4, kIndex3, kIndex2, kIndex1}; - AnfNodePtr perm_reverse_param_node = - opt::BuildIntVecParameterNode(func_graph, perm_reverse, cnode->fullname_with_scope() + "_trans_reverse_perm"); - MS_CHECK_TRUE_RET(lora_up_param_node != nullptr && lora_down_param_node != nullptr && - add_weights_param_node != nullptr && alpha_param_node != nullptr && perm_param_node != nullptr && - perm_reverse_param_node != nullptr, - RET_ERROR); - // transpose up - auto transpose_up_inputs = {lora_up_param_node, perm_reverse_param_node}; - AnfNodePtr transpose_up_param = nullptr; - if (CreateTransposeNode(transpose_up_inputs, func_graph, node, "_lora_up_transpose", &transpose_up_param) != RET_OK) { - MS_LOG(ERROR) << "Create transpose node failed! ret:" << ret << "!"; - return RET_ERROR; - } - // transpose down - auto transpose_down_inputs = {lora_down_param_node, perm_reverse_param_node}; - AnfNodePtr transpose_down_param = nullptr; - if (CreateTransposeNode(transpose_down_inputs, func_graph, node, "_lora_down_transpose", &transpose_down_param) != - RET_OK) { - MS_LOG(ERROR) << "Create transpose node failed! ret:" << ret << "!"; - return RET_ERROR; - } - // bmm - auto bmm_inputs = {transpose_down_param, transpose_up_param}; - AnfNodePtr bmm_param = nullptr; - if (CreateBMMNode(bmm_inputs, func_graph, node, "_lora_bmm", &bmm_param) != RET_OK) { - MS_LOG(ERROR) << "Create bmm node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto mul_alpha_inputs = {bmm_param, alpha_param_node}; - AnfNodePtr alpha_mul_param = nullptr; - if (CreateMulNode(mul_alpha_inputs, func_graph, node, "_lora_alpha_mul", &alpha_mul_param) != RET_OK) { - MS_LOG(ERROR) << "Create mul node failed! 
ret:" << ret << "!"; - return RET_ERROR; - } - auto mul_add_weights_inputs = {alpha_mul_param, add_weights_param_node}; - AnfNodePtr mul_add_weights_param = nullptr; - if (CreateMulNode(mul_add_weights_inputs, func_graph, node, "_lora_add_weights_mul", &mul_add_weights_param) != - RET_OK) { - MS_LOG(ERROR) << "Create mul node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto reduce_sum_inputs = {mul_add_weights_param, axes_param_node}; - AnfNodePtr reduce_sum_param = nullptr; - if (CreateReduceSumNode(reduce_sum_inputs, func_graph, node, "_lora_reduce_sum", &reduce_sum_param) != RET_OK) { - MS_LOG(ERROR) << "Create reducesum node failed! ret:" << ret << "!"; - return RET_ERROR; - } - // transpose - auto transpose_inputs = {reduce_sum_param, perm_param_node}; - AnfNodePtr transpose_param = nullptr; - if (CreateTransposeNode(transpose_inputs, func_graph, node, "_lora_transpose", &transpose_param) != RET_OK) { - MS_LOG(ERROR) << "Create transpose node failed! ret:" << ret << "!"; - return RET_ERROR; - } - // add - auto add_inputs = {transpose_param, weight}; - CNodePtr add_cnode = nullptr; - if (CreateAddNode(add_inputs, func_graph, node, "_lora_add", &add_cnode) != RET_OK) { - MS_LOG(ERROR) << "Create add node failed! ret:" << ret << "!"; - return RET_ERROR; - } - auto manager = Manage(func_graph); - (void)manager->Replace(weight, add_cnode); - return RET_OK; -} - -STATUS InsertVariableNodePass::ParseShapeStr(std::string shape_str, std::vector *shape) { - int shape_len = shape_str.size(); - std::string shape_nums = shape_str.substr(1, shape_len - 2); - std::stringstream ss(shape_nums); - std::string token; - while (std::getline(ss, token, ',')) { - shape->push_back(std::stoi(token)); - } - return RET_OK; -} - -lite::STATUS InsertVariableNodePass::ParseInsertNode(std::string file_path, - std::map> *variable_nodes, - std::unordered_map *node_name_map, - std::vector *node_name_list, bool *has_alpha) { +lite::STATUS InsertVariableNodePass::ParseInsertNode(std::string file_path, std::set *variable_nodes, + std::vector *node_name_list) { MS_CHECK_TRUE_RET(variable_nodes != nullptr, lite::RET_NULL_PTR); - MS_CHECK_TRUE_RET(node_name_map != nullptr, lite::RET_NULL_PTR); MS_CHECK_TRUE_RET(node_name_list != nullptr, lite::RET_NULL_PTR); - MS_CHECK_TRUE_RET(has_alpha != nullptr, lite::RET_NULL_PTR); std::ifstream file; auto ret = lite::ReadFileToIfstream(file_path, &file); if (ret != RET_OK) { @@ -463,6 +141,11 @@ lite::STATUS InsertVariableNodePass::ParseInsertNode(std::string file_path, size_t config_len = 0; std::string line; while (std::getline(file, line)) { + if (!MatchPattern(line)) { + MS_LOG(ERROR) << "Format of config error, it should be 'weight_name:num1,num2,num3;node_name', input config:" + << line; + return RET_ERROR; + } config_len++; if (config_len >= kMaxConfigLen) { MS_LOG(ERROR) << "Support max config len is " << kMaxConfigLen << ", current len:" << config_len << "!"; @@ -475,103 +158,16 @@ lite::STATUS InsertVariableNodePass::ParseInsertNode(std::string file_path, return RET_ERROR; } auto variable_para_name = line.substr(0, pos_colon); - if (variable_para_name.find("alpha") != std::string::npos && (*has_alpha) != true) { - (*has_alpha) = true; - } - auto pos_semicolon = line.find(';'); - if (pos_semicolon == std::string::npos) { - MS_LOG(ERROR) << "Parse variable weight file error!"; - file.close(); - return RET_ERROR; - } - auto weight_shape_str = line.substr(pos_colon + 1, pos_semicolon - pos_colon - 1); - auto node_name = line.substr(pos_semicolon + 1); - std::string 
record_name = ""; - if (variable_para_name.find(".up.") != std::string::npos || - variable_para_name.find("lora_up") != std::string::npos) { - record_name = node_name + "variable_up"; - } else if (variable_para_name.find(".down.") != std::string::npos || - variable_para_name.find("lora_down") != std::string::npos) { - record_name = node_name + "variable_down"; - } else if (variable_para_name.find("alpha") != std::string::npos) { - record_name = node_name + "variable_alpha"; - } else { - MS_LOG(ERROR) << "Only support up weight, down weight and alpha!"; - return RET_ERROR; - } - if (node_name_map->find(record_name) == node_name_map->end()) { - (*node_name_map)[record_name] = ""; - (*node_name_list).push_back(record_name); - } - // Only Upsape is recorded, so that you can easily check node name - if (variable_para_name.find(".up.") == std::string::npos && - variable_para_name.find("lora_up") == std::string::npos) { - continue; - } - std::vector shape; - ret = ParseShapeStr(weight_shape_str, &shape); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ParseShapeStr error! ret:" << ret << "!"; - file.close(); - return ret; - } - variable_nodes->insert({node_name, shape}); + (*node_name_list).push_back(variable_para_name); + variable_nodes->insert(variable_para_name); } file.close(); return RET_OK; } -lite::STATUS InsertVariableNodePass::CheckOnlyReplace(CNodePtr cnode, const std::vector ¶_shape, - const bool &is_matmul, bool *compare_res) { - MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cnode is nullptr!"); - MS_CHECK_TRUE_MSG(compare_res != nullptr, RET_ERROR, "compare_res is nullptr!"); - auto weight = cnode->input(kInputIndex2); - MS_CHECK_TRUE_RET(weight != nullptr, RET_ERROR); - ShapeVector weight_shape; - auto ret = FetchWeightShape(weight, &weight_shape, cnode, is_matmul); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "fetch wieght shape failed! ret:" << ret << "!"; - return ret; - } - if (weight_shape.size() != para_shape.size()) { - *compare_res = false; - return RET_OK; - } - *compare_res = std::equal(weight_shape.begin(), weight_shape.end(), para_shape.begin()); - return RET_OK; -} - -lite::STATUS InsertVariableNodePass::RecordVariableName(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const string &search_key, bool is_matmul, - std::unordered_map *node_name_map) { - MS_CHECK_TRUE_RET(node_name_map != nullptr, RET_ERROR); - MS_CHECK_TRUE_RET(cnode != nullptr, RET_ERROR); - if (cnode->inputs().size() < kInputSize3) { - MS_LOG(ERROR) << "Weight size must greater than 3, current size:" << cnode->inputs().size() << "!"; - return RET_ERROR; - } - auto weight = cnode->input(kInputIndex2); - MS_CHECK_TRUE_RET(weight != nullptr, RET_ERROR); - ShapeVector weight_shape; - auto ret = FetchWeightShape(weight, &weight_shape, cnode, is_matmul); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "fetch wieght shape failed! 
ret:" << ret << "!"; - return ret; - } - AnfNodePtr fp16_weight = BuildZeroVecNDParameterNode( - func_graph, weight_shape, weight->fullname_with_scope(), 0.0, kNumberTypeFloat16); - MS_CHECK_TRUE_MSG(fp16_weight != nullptr, RET_ERROR, "fp16_weight is nullptr!"); - (*node_name_map)[search_key + "variable_up"] = weight->fullname_with_scope() + "_const"; - auto manager = Manage(func_graph); - MS_CHECK_TRUE_RET(manager != nullptr, RET_ERROR); - (void)manager->Replace(weight, fp16_weight); - return RET_OK; -} - -void InsertVariableNodePass::InitWeightParam(const std::shared_ptr ¶m, - std::string *variable_weights_file, int32_t *max_weight_batch) { - if (param->config_infos.find(lite::kAscendContextSection) != param->config_infos.end()) { - auto ascend_context = param->config_infos.at(lite::kAscendContextSection); +void InsertVariableNodePass::InitWeightParam(std::string *variable_weights_file, int32_t *max_weight_batch) { + if (param_->config_infos.find(lite::kAscendContextSection) != param_->config_infos.end()) { + auto ascend_context = param_->config_infos.at(lite::kAscendContextSection); if (ascend_context.find(lite::kVariableWeightsFile) != ascend_context.end()) { *variable_weights_file = ascend_context.at(lite::kVariableWeightsFile); } @@ -582,115 +178,111 @@ void InsertVariableNodePass::InitWeightParam(const std::shared_ptr *node_name_map) { + const FuncGraphPtr &func_graph, const ParameterPtr ¶_node, const string &search_key, + std::unordered_map *node_name_map, + std::unordered_map *node_abstract_map) { MS_CHECK_TRUE_RET(node_name_map != nullptr, RET_ERROR); MS_CHECK_TRUE_RET(para_node != nullptr, RET_ERROR); + (*node_name_map)[search_key] = para_node->fullname_with_scope() + "_const"; + (*node_abstract_map)[search_key] = para_node->abstract(); + return RET_OK; +} - ShapeVector weight_shape; - auto ret = FetchWeightShape(para_node, &weight_shape, nullptr, is_matmul); - if (ret != lite::RET_OK) { - MS_LOG(ERROR) << "fetch wieght shape failed! 
ret:" << ret << "!"; - return ret; - } - TypeId weight_type = FetchTypeIdByNode(para_node); - AnfNodePtr node_weight; - if (weight_type == kNumberTypeFloat32) { - node_weight = BuildZeroVecNDParameterNode(func_graph, weight_shape, para_node->fullname_with_scope(), 0.0, - kNumberTypeFloat32); - } else if (weight_type == kNumberTypeFloat16) { - node_weight = BuildZeroVecNDParameterNode(func_graph, weight_shape, para_node->fullname_with_scope(), 0.0, - kNumberTypeFloat16); - } else if (weight_type == kNumberTypeInt32) { - node_weight = BuildZeroVecNDParameterNode(func_graph, weight_shape, para_node->fullname_with_scope(), 0.0, - kNumberTypeInt32); - } else if (weight_type == kNumberTypeInt16) { - node_weight = BuildZeroVecNDParameterNode(func_graph, weight_shape, para_node->fullname_with_scope(), 0.0, - kNumberTypeInt16); - } else { - MS_LOG(ERROR) << "replace parameter data type " << weight_type << " not supported!"; - return RET_ERROR; +FuncGraphPtr InsertVariableNodePass::CreateUpdateGraph(const std::vector &const_names, + const std::vector &abstarcts) { + MS_CHECK_TRUE_MSG(const_names.size() == abstarcts.size(), nullptr, + "size of const_names must equal to size of abstracts's size!"); + auto func_graph = std::make_shared(); + std::vector graph_outputs = {}; + for (size_t i = 0; i < const_names.size(); i++) { + auto param = func_graph->add_parameter(); + MS_CHECK_TRUE_MSG(param != nullptr, nullptr, "param is nullptr!"); + auto name = const_names[i]; + auto abstract = abstarcts[i]; + MS_CHECK_TRUE_MSG(abstract != nullptr, nullptr, "abstract is nullptr!"); + param->set_abstract(abstract->Clone()); // node name and abstract of weight + param->set_name(name + "_data"); + auto variable_prim = std::make_unique(); + MS_CHECK_TRUE_MSG(variable_prim != nullptr, nullptr, "variable_prim is nullptr!"); + variable_prim->set_type("Variable"); + std::vector variable_input_names = {"x"}; + std::vector variable_output_names = {"y"}; + variable_prim->AddAttr("input_names", api::MakeValue(variable_input_names)); + variable_prim->AddAttr("output_names", api::MakeValue(variable_output_names)); + variable_prim->AddAttr(kAttrRegOpName, api::MakeValue("Variable")); + auto variable_prim_c = variable_prim->GetPrim(); + MS_CHECK_TRUE_MSG(variable_prim_c != nullptr, nullptr, "variable_prim_c is nullptr!"); + auto variable_cnode = func_graph->NewCNode(variable_prim_c, {}); + MS_CHECK_TRUE_MSG(variable_cnode != nullptr, nullptr, "variable_cnode is nullptr"); + variable_cnode->set_fullname_with_scope(name + "_var"); + variable_cnode->set_abstract(abstract->Clone()); + auto assign_prim = std::make_unique(); + MS_CHECK_TRUE_MSG(assign_prim != nullptr, nullptr, "assign_prim is nullptr!"); + assign_prim->set_type("Assign"); + std::vector assign_input_names = {"input0", "input1"}; + std::vector assign_output_names = {"output0"}; + assign_prim->AddAttr("input_names", api::MakeValue(assign_input_names)); + assign_prim->AddAttr("output_names", api::MakeValue(assign_output_names)); + assign_prim->AddAttr(kAttrRegOpName, api::MakeValue("Assign")); + auto assign_prim_c = assign_prim->GetPrim(); + MS_CHECK_TRUE_MSG(assign_prim_c != nullptr, nullptr, "assign_prim_c is nullptr!"); + std::vector assign_inputs = {variable_cnode, param}; + auto assign_cnode = func_graph->NewCNode(assign_prim_c, assign_inputs); + MS_CHECK_TRUE_MSG(assign_cnode != nullptr, nullptr, "assign_cnode is nullptr!"); + assign_cnode->set_fullname_with_scope(name + "_assign"); + assign_cnode->set_abstract(abstract->Clone()); + 
graph_outputs.push_back(assign_cnode); + } + // update graph should not has output + // single output graph insert identity node as graph output, then delete it in IdentityOptimization function to make + // sure update graph has no output node. + if (graph_outputs.size() == 1) { + auto prim = std::make_unique(); + MS_CHECK_TRUE_MSG(prim != nullptr, nullptr, "prim is nullptr!"); + auto prim_c = prim->GetPrim(); + MS_CHECK_TRUE_MSG(prim_c != nullptr, nullptr, "prim_c is nullptr!"); + auto identity_cnode = func_graph->NewCNode(prim_c, graph_outputs); + MS_CHECK_TRUE_MSG(identity_cnode != nullptr, nullptr, "identity_cnode is nullptr"); + identity_cnode->set_abstract(graph_outputs[0]->abstract()->Clone()); + identity_cnode->set_fullname_with_scope(graph_outputs[0]->fullname_with_scope() + "_identity"); + graph_outputs[0] = identity_cnode; + } + auto ret = BuildReturnNode(func_graph, graph_outputs); + if (ret != kSuccess) { + MS_LOG(ERROR) << "BuildReturnNode failed!"; + return nullptr; } - - MS_CHECK_TRUE_MSG(node_weight != nullptr, RET_ERROR, "node_weight is nullptr!"); - (*node_name_map)[search_key + "variable_up"] = para_node->fullname_with_scope() + "_const"; - auto manager = Manage(func_graph); - MS_CHECK_TRUE_RET(manager != nullptr, RET_ERROR); - (void)manager->Replace(para_node, node_weight); - return RET_OK; + func_graph->set_attr("is_update_graph", MakeValue(true)); + return func_graph; } -lite::STATUS InsertVariableNodePass::BuildVariableNode(const std::shared_ptr ¶m, - FuncGraphPtr func_graph, std::vector *const_names) { +lite::STATUS InsertVariableNodePass::BuildVariableNode(FuncGraphPtr func_graph) { MS_CHECK_TRUE_RET(func_graph != nullptr, RET_ERROR); std::string variable_weights_file = ""; int32_t max_weight_batch = 1; - InitWeightParam(param, &variable_weights_file, &max_weight_batch); + InitWeightParam(&variable_weights_file, &max_weight_batch); MS_CHECK_TRUE_RET(variable_weights_file != "", RET_OK); - bool has_alpha = false; - std::map> variable_nodes; + std::set variable_nodes; std::unordered_map node_name_map; + std::unordered_map node_abstract_map; std::vector node_name_list; - auto ret = ParseInsertNode(variable_weights_file, &variable_nodes, &node_name_map, &node_name_list, &has_alpha); + auto ret = ParseInsertNode(variable_weights_file, &variable_nodes, &node_name_list); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "ParseInsertNode failed!"); uint32_t matched_num = 0; auto node_list = TopoSort(func_graph->get_return()); for (auto &node : node_list) { MS_CHECK_TRUE_RET(node != nullptr, false); auto node_name = node->fullname_with_scope(); - size_t last_slash_pos = node_name.find_last_of('/'); - std::string search_key = ""; if (utils::isa(node)) { - search_key = node_name; - if (variable_nodes.find(search_key) == variable_nodes.end()) { + if (variable_nodes.find(node_name) == variable_nodes.end()) { continue; } auto parameter = node->cast(); if (parameter == nullptr || !parameter->has_default()) { continue; } - ret = RecordParameterVariableName(func_graph, parameter, search_key, false, &node_name_map); + ret = RecordParameterVariableName(func_graph, parameter, node_name, &node_name_map, &node_abstract_map); MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record parameter variable name failed!"); - } else if (utils::isa(node)) { - if (last_slash_pos != std::string::npos) { - search_key = node_name.substr(0, last_slash_pos); - } else { - MS_LOG(INFO) << "Not found last slash, Cnode name:" << node->fullname_with_scope() << "!"; - continue; - } - if (variable_nodes.find(search_key) == 
variable_nodes.end()) { - continue; - } - auto cnode = utils::cast(node); - MS_CHECK_TRUE_RET(cnode != nullptr, false); - if (mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulV2) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimMatMulFusion) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimBatchMatMul)) { - bool replace_origin = false; - ret = CheckOnlyReplace(cnode, variable_nodes.at(search_key), true, &replace_origin); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "CheckOnlyReplace failed!"); - if (replace_origin) { - ret = RecordVariableName(func_graph, cnode, search_key, true, &node_name_map); - } else { - ret = InsertVariableNodeForMatmul(node, cnode, func_graph, variable_nodes.at(search_key), &node_name_map, - has_alpha, max_weight_batch); - } - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record variable name failed!"); - } else if (mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2D) || - mindspore::opt::CheckPrimitiveType(node, mindspore::prim::kPrimConv2DFusion)) { - bool replace_origin = false; - ret = CheckOnlyReplace(cnode, variable_nodes.at(search_key), false, &replace_origin); - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "CheckOnlyReplace failed!"); - if (replace_origin) { - ret = RecordVariableName(func_graph, cnode, search_key, false, &node_name_map); - } else { - ret = InsertVariableNodeForConv(node, cnode, func_graph, variable_nodes.at(search_key), &node_name_map, - has_alpha, max_weight_batch); - } - MS_CHECK_TRUE_MSG(ret == RET_OK, ret, "Record variable name failed!"); - } else { - continue; - } } else { continue; } @@ -704,16 +296,24 @@ lite::STATUS InsertVariableNodePass::BuildVariableNode(const std::shared_ptrpush_back(node_name_map[s]); + param_->variable_node_names.push_back(node_name_map[s]); + param_->variable_node_abstracts.push_back(node_abstract_map[s]); + } + auto update_graph = CreateUpdateGraph(param_->variable_node_names, param_->variable_node_abstracts); + if (update_graph == nullptr) { + MS_LOG(ERROR) << "update graph is nullptr!"; + return RET_ERROR; } + param_->update_graph = update_graph; return RET_OK; } bool InsertVariableNodePass::Run(const FuncGraphPtr &graph) { - if (BuildVariableNode(param_, graph, &(param_->const_names)) != RET_OK) { + if (BuildVariableNode(graph) != RET_OK) { + MS_LOG(ERROR) << "build variable node failed!"; return false; } - if (param_->const_names.size() > 0) { + if (param_->variable_node_names.size() > 0) { graph->set_attr(lite::kBundleModel, MakeValue("True")); } return true; diff --git a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h index ba623d67..c2bbb52e 100644 --- a/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h +++ b/mindspore-lite/tools/optimizer/graph/add_variable_node_pass.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "include/backend/optimizer/pass.h" #include "tools/converter/cxx_api/converter_para.h" #include "include/errorcode.h" @@ -36,35 +37,19 @@ class InsertVariableNodePass : public Pass { bool Run(const FuncGraphPtr &graph) override; private: - lite::STATUS BuildVariableNode(const std::shared_ptr ¶m, FuncGraphPtr func_graph, - std::vector *const_names); - lite::STATUS InsertVariableNodeForMatmul(const AnfNodePtr &node, const CNodePtr &cnode, - const FuncGraphPtr &func_graph, const std::vector &up_shape, - std::unordered_map *node_name_map, bool has_alpha, - int max_weight_batch); - lite::STATUS InsertVariableNodeForConv(const 
AnfNodePtr &node, const CNodePtr &cnode, const FuncGraphPtr &func_graph,
-                                         const std::vector<int64_t> &up_shape,
-                                         std::unordered_map<std::string, std::string> *node_name_map, bool has_alpha,
-                                         int max_weight_batch);
-  lite::STATUS ParseInsertNode(std::string file_path, std::map<std::string, std::vector<int64_t>> *variable_nodes,
-                               std::unordered_map<std::string, std::string> *node_name_map,
-                               std::vector<std::string> *node_name_list, bool *has_alpha);
-  lite::STATUS ParseShapeStr(std::string shape_str, std::vector<int64_t> *shape);
-  lite::STATUS InsertVariableAddNode(const CNodePtr &cnode, const FuncGraphPtr &func_graph, const bool &is_matmul,
-                                     std::unordered_map<std::string, std::string> *node_name_map);
-  lite::STATUS CheckOnlyReplace(CNodePtr cnode, const std::vector<int64_t> &para_shape, const bool &is_matmul,
-                                bool *compare_res);
-  lite::STATUS RecordVariableName(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const string &search_key,
-                                  bool is_matmul, std::unordered_map<std::string, std::string> *node_name_map);
-  lite::STATUS RecordParameterVariableName(const FuncGraphPtr &func_graph, const ParameterPtr &para_node,
-                                           const string &search_key, bool is_matmul,
-                                           std::unordered_map<std::string, std::string> *node_name_map);
+  lite::STATUS BuildVariableNode(FuncGraphPtr func_graph);
+  lite::STATUS ParseInsertNode(std::string file_path, std::set<std::string> *variable_nodes,
+                               std::vector<std::string> *node_name_list);
   template <typename T>
   ParameterPtr BuildZeroVecNDParameterNode(const FuncGraphPtr &anf_graph, ShapeVector weight_shape,
                                            const std::string &node_name, T value, TypeId dtype);
-  void InitWeightParam(const std::shared_ptr<ConverterPara> &param, std::string *variable_weights_file,
-                       int32_t *max_weight_batch);
-
+  void InitWeightParam(std::string *variable_weights_file, int32_t *max_weight_batch);
+  FuncGraphPtr CreateUpdateGraph(const std::vector<std::string> &const_names,
+                                 const std::vector<AbstractBasePtr> &abstracts);
+  lite::STATUS RecordParameterVariableName(const FuncGraphPtr &func_graph, const ParameterPtr &para_node,
+                                           const string &search_key,
+                                           std::unordered_map<std::string, std::string> *node_name_map,
+                                           std::unordered_map<std::string, AbstractBasePtr> *node_abstract_map);
   std::shared_ptr<ConverterPara> param_;
 };
 }  // namespace opt
-- 
Gitee
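
Note: each line of the variable_weights_file consumed by this patch must match the regex enforced by MatchPattern in add_variable_node_pass.cc, i.e. weight_name:dim1,dim2,...;node_name, as in the shipped test config line input_matrix:4,4;node_matmul. A minimal Python sketch of the same validation (the helper name is ours for illustration, not part of the patch):

import re

# Same pattern as MatchPattern in add_variable_node_pass.cc:
# weight_name:dim1,dim2,...;node_name
LINE_PATTERN = re.compile(r"^([^:;]+):(\d+(?:,\d+)*);([^:;]+)$")

def parse_variable_weights_line(line):
    """Split one config line into (weight_name, shape, node_name)."""
    match = LINE_PATTERN.match(line.strip())
    if match is None:
        raise ValueError("Format of config error, expected 'weight_name:num1,num2;node_name': " + line)
    weight_name, shape_str, node_name = match.groups()
    return weight_name, [int(dim) for dim in shape_str.split(",")], node_name

# Example line from single_matmul_model.onnx.variable_weights_file:
print(parse_variable_weights_line("input_matrix:4,4;node_matmul"))
# -> ('input_matrix', [4, 4], 'node_matmul')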