diff --git a/mindspore/ccsrc/frontend/jit/ps/action.cc b/mindspore/ccsrc/frontend/jit/ps/action.cc index 207b3e3acb3361c85d511ee0c8224cdd26e2e19e..c1e5edd033515ec7682191deb205216d3ffdce9c 100644 --- a/mindspore/ccsrc/frontend/jit/ps/action.cc +++ b/mindspore/ccsrc/frontend/jit/ps/action.cc @@ -2174,6 +2174,8 @@ std::vector VmPipeline(const ResourcePtr &resource, bool trace_flag, // Mind Compiler finish. (void)actions.emplace_back(std::make_pair(kValidate, ValidateAction)); + + (void)actions.emplace_back(std::make_pair(kBackendPass, BackendPass)); } if (erase_parse) { @@ -2191,8 +2193,6 @@ std::vector VmPipeline(const ResourcePtr &resource, bool trace_flag, return actions; } - (void)actions.emplace_back(std::make_pair(kBackendPass, BackendPass)); - auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); #ifndef WITH_BACKEND @@ -2225,7 +2225,9 @@ std::vector MindIRPipeline() { std::vector JitPipeline(const ResourcePtr &resource, bool build_top_graph) { std::vector jit_passes; - if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) { + bool init_null = resource->func_graph() == nullptr; + bool is_jit_grad = pynative::GradState::Get().RequiresGrad(); + if (!resource->EnableCompileCache() || init_null) { // Compile the frontend graph. 
if (build_top_graph) { (void)jit_passes.emplace_back(kBootstrap, BootstrapAction); @@ -2263,6 +2265,7 @@ std::vector JitPipeline(const ResourcePtr &resource, bool build_top_gr } (void)jit_passes.emplace_back(kSymbolEngineOpt, SymbolEngineOptGroup); (void)jit_passes.emplace_back(kValidate, ValidatePass); + (void)jit_passes.emplace_back(std::make_pair(kBackendPass, BackendPass)); } auto is_precompile_only = MsContext::GetInstance()->get_param(MS_CTX_PRECOMPILE_ONLY) || @@ -2272,11 +2275,15 @@ std::vector JitPipeline(const ResourcePtr &resource, bool build_top_gr return jit_passes; } - if (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileGraph) { + if (resource->EnableCompileCache() && !init_null && is_jit_grad) { + // Store forward graph for jit grad when using compile cache. + (void)jit_passes.emplace_back(kGetJitBpropGraph, GetJitBpropGraph); return jit_passes; } - (void)jit_passes.emplace_back(std::make_pair(kBackendPass, BackendPass)); + if (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileGraph) { + return jit_passes; + } #ifndef WITH_BACKEND auto ms_context = MsContext::GetInstance(); diff --git a/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.cc b/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.cc index ffe1ab9568c369f9e1eb770d3495312c5944dc93..2dff1eba14a8504e7c83e5dc2b422359d24c89cf 100644 --- a/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.cc +++ b/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.cc @@ -118,8 +118,9 @@ std::string GetRole() { return ""; } -std::string GetCompileCachePath(size_t idx) { - return GetGraphCacheDir() + "/" + GetRole() + kCompileCacheFileName + "_" + std::to_string(idx) + kMindIrSuffix; +std::string GetCompileCachePath(const std::string &id_extension, size_t idx) { + return GetGraphCacheDir() + "/" + GetRole() + kCompileCacheFileName + "_" + std::to_string(idx) + id_extension + + kMindIrSuffix; } std::string GetBackendCompileCachePathWithoutExtension(size_t idx) { @@ -173,11 +174,11 
@@ std::map GenerateWeightsValueMap(const py::dict &weights) { return ret; } -std::pair LoadFuncGraphFromMindIR(const py::dict &weights, bool has_parallel_info, - size_t idx) { +std::pair LoadFuncGraphFromMindIR(const py::dict &weights, const std::string &id_extension, + bool has_parallel_info, size_t idx) { MsProfileStatGuard stat_guard("LoadFuncGraphFromMindIR", "compile_cache", true); LayoutMap layout_map; - std::string compile_cache_path = GetCompileCachePath(idx); + std::string compile_cache_path = GetCompileCachePath(id_extension, idx); auto realpath = Common::CreatePrefixPath(compile_cache_path, true); if (!realpath.has_value()) { MS_LOG(ERROR) << "Get real path of file " << compile_cache_path << " failed."; @@ -213,8 +214,9 @@ std::pair LoadFuncGraphFromMindIR(const py::dict &weigh return std::make_pair(fg, mindir_loader.layout_map()); } -bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, size_t idx) { - std::string compile_cache_path = GetCompileCachePath(idx); +bool ExportFuncGraphToMindIR(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, const std::string &id_extension, + size_t idx) { + std::string compile_cache_path = GetCompileCachePath(id_extension, idx); auto proto = GenBinaryProto(fg); if (proto == nullptr) { MS_LOG(ERROR) << "Get binary proto for graph " << fg->ToString() << " failed."; @@ -370,25 +372,31 @@ std::string CompileCacheManager::GetCachedDataQueueName(const std::string &datas return queue_name; } -void CompileCacheManager::CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg) { +void CompileCacheManager::CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, bool cache_hash, + bool backward_graph) { if (fg == nullptr) { MS_LOG(ERROR) << "The func_graph to be cached is null."; return; } - const auto &queue_name = GetDataQueueName(fg); - auto dataset_phase = ConfigManager::GetInstance().dataset_phase(); - if (!ExportDataQueueName(dataset_phase, queue_name)) { - 
MS_LOG(ERROR) << "Failed to cache data queue name: " << queue_name; - return; + if (!backward_graph) { + const auto &queue_name = GetDataQueueName(fg); + auto dataset_phase = ConfigManager::GetInstance().dataset_phase(); + if (!ExportDataQueueName(dataset_phase, queue_name)) { + MS_LOG(ERROR) << "Failed to cache data queue name: " << queue_name; + return; + } } SetCompileCacheDir(GetCompileCacheDir()); - if (!ExportFuncGraphToMindIR(fg, layout_fg, compile_cache_id_)) { + if (!ExportFuncGraphToMindIR(fg, layout_fg, id_extension_, compile_cache_id_)) { MS_LOG(ERROR) << "Failed to cache graph: " << fg->ToString(); return; } + if (!cache_hash) { + return; + } if (compile_cache_id_ == 0 && !ExportDepFilesHash(compile_cache_dep_files_hash_)) { MS_LOG(ERROR) << "Failed to cache the dependency files hash"; } @@ -427,7 +435,7 @@ bool CompileCacheManager::CanLoadCache() { MS_LOG(WARNING) << "The compilation dependency files are changed."; return false; } - auto compile_cache_path = GetCompileCachePath(compile_cache_id_); + auto compile_cache_path = GetCompileCachePath(id_extension_, compile_cache_id_); struct stat buffer; if (stat(compile_cache_path.c_str(), &buffer) != 0) { MS_LOG(WARNING) << "Failed to find cache file, execute all the compilation actions."; @@ -437,7 +445,7 @@ bool CompileCacheManager::CanLoadCache() { } FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr &manager, const py::dict &weights, - const std::string &queue_name) { + const std::string &queue_name, bool backward_graph) { // Determine whether to load parallel information. std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); bool has_parallel_info = false; @@ -450,7 +458,7 @@ FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr & has_parallel_info = true; } // Load the compilation cache file. 
- auto pair = LoadFuncGraphFromMindIR(weights, has_parallel_info, compile_cache_id_); + auto pair = LoadFuncGraphFromMindIR(weights, id_extension_, has_parallel_info, compile_cache_id_); if (pair.first == nullptr) { MS_LOG(WARNING) << "Failed to load the compilation cache file. Execute all the compilation actions."; return nullptr; } @@ -461,24 +469,26 @@ FuncGraphPtr CompileCacheManager::GetCachedFuncGraph(const FuncGraphManagerPtr & MS_LOG(WARNING) << "Use the compilation cache and execute the backend actions only. Be aware of correctness risks."; FuncGraphManagerPtr mng = fg->manager(); if (mng == nullptr) { - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(fg); - fg->set_manager(manager); - } - // The value of attr "shared_name" will changed every time. - auto cnodes = fg->GetOrderedCnodes(); - for (const auto &cnode : cnodes) { - auto prim = GetValuePtr(cnode->input(0)); - if (prim != nullptr && prim->HasAttr("shared_name")) { - prim->set_attr("shared_name", MakeValue(queue_name)); - break; + auto new_manager = manager == nullptr ? MakeManager({fg}, false) : manager; + new_manager->AddFuncGraph(fg); + fg->set_manager(new_manager); + } + if (!backward_graph) { + // The value of attr "shared_name" will change every time. 
+ auto cnodes = fg->GetOrderedCnodes(); + for (const auto &cnode : cnodes) { + auto prim = GetValuePtr(cnode->input(0)); + if (prim != nullptr && prim->HasAttr("shared_name")) { + prim->set_attr("shared_name", MakeValue(queue_name)); + break; + } } } #ifdef ENABLE_DUMP_IR auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); if (context->CanDump(kIntroductory)) { - DumpIR("cache_loaded_graph_" + std::to_string(compile_cache_id_) + ".ir", fg); + DumpIR("cache_loaded_graph_" + std::to_string(compile_cache_id_) + id_extension_ + ".ir", fg); } #endif return fg; diff --git a/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.h b/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.h index f3f4f84ac88df08d395f43dc26c6ae6c328f40d6..526f64e0183624ed66561f08d9de04d093946682 100644 --- a/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.h +++ b/mindspore/ccsrc/frontend/jit/ps/compile_cache_manager.h @@ -46,14 +46,16 @@ class CompileCacheManager { bool CanLoadCache(); // Load the cached func_graph from mindir file. FuncGraphPtr GetCachedFuncGraph(const FuncGraphManagerPtr &manager, const py::dict &weights, - const std::string &queue_name); + const std::string &queue_name, bool backward_graph = false); // Export the func_graph to mindir file. 
- void CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg); + void CacheFuncGraph(const FuncGraphPtr &fg, const FuncGraphPtr &layout_fg, bool cache_hash = true, + bool backward_graph = false); const LayoutMap &layout_map() const { return layout_map_; } void SetCompileCacheDir(const std::string &dir) { compile_cache_dir_ = dir; } std::string CompileCacheDir() const { return compile_cache_dir_; } + void SetIdExtension(const std::string &id_extension) { id_extension_ = id_extension; } static size_t data_queue_num_; private: @@ -61,6 +63,7 @@ class CompileCacheManager { std::string compile_cache_dep_files_hash_; LayoutMap layout_map_; std::string compile_cache_dir_; + std::string id_extension_; }; using CompileCacheManagerPtr = std::shared_ptr; } // namespace pipeline diff --git a/mindspore/ccsrc/frontend/jit/ps/executor/executor_py.h b/mindspore/ccsrc/frontend/jit/ps/executor/executor_py.h index 67f21418f501c21bb45037f80e9ce113439e2869..35e60e9db1d21015412f02f714f3605213453bf1 100644 --- a/mindspore/ccsrc/frontend/jit/ps/executor/executor_py.h +++ b/mindspore/ccsrc/frontend/jit/ps/executor/executor_py.h @@ -61,7 +61,9 @@ class FRONTEND_EXPORT ExecutorPy : public std::enable_shared_from_this &passes) bool already_print_profile = false; ProfileExecute(MsProfile::GetProfile(), [&resource, &passes, &already_print_profile]() { static const std::string last_compile_action = kValidate; + static const std::string jit_grad_last_compile_action = kGetJitBpropGraph; + bool is_jit_grad = pynative::GradState::Get().RequiresGrad(); static const auto compile_profile_finish_action = common::GetCompileConfig("COMPILE_PROFILE_FINISH_ACTION"); size_t counter = 0; for (auto &pass : passes) { @@ -89,11 +92,15 @@ void Optimize(const ResourcePtr &resource, const std::vector &passes) }; ProfileExecute(profile_context, pass_func); ProcessStatus::GetInstance().RecordEnd(); + if (pass.first == jit_grad_last_compile_action && is_jit_grad) { + CacheFuncGraph(resource); + } else 
if (pass.first == last_compile_action && !is_jit_grad) { + CacheFuncGraph(resource); + } if (pass.first == kTaskEmit) { SetLoopCount(resource); } else if (pass.first == last_compile_action) { CheckInterpretNodeLineInfos(); - CacheFuncGraph(resource); ResetId(resource); } else if (pass.first == kAutoMonadReorder) { resource->set_optimize_graph(resource->func_graph()); diff --git a/mindspore/ccsrc/frontend/jit/ps/pipeline.cc b/mindspore/ccsrc/frontend/jit/ps/pipeline.cc index 836d84d9023c4915410835063700a14bd9041c53..99cee90eee4b9004e096b787b8e8299586543a7a 100644 --- a/mindspore/ccsrc/frontend/jit/ps/pipeline.cc +++ b/mindspore/ccsrc/frontend/jit/ps/pipeline.cc @@ -569,9 +569,11 @@ void Pipeline::Run() { MS_EXCEPTION_IF_NULL(resource_); FuncGraphPtr user_graph = nullptr; const std::string last_compile_action = kValidate; + const std::string last_compile_action_for_compile_cache = kBackendPass; bool already_print_profile = false; static const auto compile_profile_finish_action = common::GetCompileConfig("COMPILE_PROFILE_FINISH_ACTION"); - ProfileExecute(MsProfile::GetProfile(), [this, &user_graph, &last_compile_action, &already_print_profile]() { + ProfileExecute(MsProfile::GetProfile(), [this, &user_graph, &last_compile_action, + &last_compile_action_for_compile_cache, &already_print_profile]() { size_t i = 0; for (auto &action : actions_) { std::string action_name = action.first; @@ -618,8 +620,9 @@ void Pipeline::Run() { SetLoopCount(resource_); } else if (action.first == last_compile_action) { CheckInterpretNodeLineInfos(); - CacheFuncGraph(resource_); ResetId(resource_); + } else if (action.first == last_compile_action_for_compile_cache) { + CacheFuncGraph(resource_); } FuncGraphPtr graph = resource_->func_graph(); #ifdef ENABLE_DUMP_IR diff --git a/mindspore/ccsrc/frontend/optimizer/ad/pynative_jit_grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/pynative_jit_grad.cc index fea70ceeab91e44d7f0908c2c42278a62d85742b..f5e05b38eac361824205d689954ba275f9446c58 100644 
--- a/mindspore/ccsrc/frontend/optimizer/ad/pynative_jit_grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/pynative_jit_grad.cc @@ -40,6 +40,10 @@ #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_t.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_u.h" #include "mindspore/ops/op_def/auto_generate/gen_ops_primitive_v.h" +#include "include/common/utils/compile_cache_context.h" +#include "frontend/jit/ps/compile_cache_manager.h" +#include "frontend/jit/ps/executor/graph_executor_py.h" +#include "frontend/jit/ps/executor/jit_executor_py.h" namespace mindspore { namespace ad { mindspore::HashMap> pass_grad_graph_; @@ -389,6 +393,78 @@ bool IsViewInplaceAbs(const AbstractBasePtr &abs) { } } // namespace +std::pair GetCompileCacheResource(const py::dict &weights, + const std::string &extension, + size_t compile_cache_id, bool backward_graph) { + pipeline::CompileCacheManagerPtr compile_cache_manager = + std::make_shared(compile_cache_id); + compile_cache_manager->SetIdExtension(extension); + compile_cache_manager->InitParallelGroupCkptSaveFile(); + const bool force_use_compile_cache = (common::GetEnv("MS_DEV_FORCE_USE_COMPILE_CACHE") == "1"); + auto &context = CompileCacheContext::GetInstance(); + auto jit_executor = pipeline::JitExecutorPy::GetInstance(); + const py::list &compile_cache_dep_files = jit_executor->get_compile_cache_dep_files(); + // When enabling compile cache, it is possible to enable it even without Python script. + if (force_use_compile_cache || compile_cache_dep_files.empty()) { + context.set_init_compile_cache(true); + MS_LOG(WARNING) + << "The env MS_DEV_FORCE_USE_COMPILE_CACHE has been set. It will force to use the compile cache without " + "checking whether the network has been changed. 
Please note the correctness."; + } else { + MsProfileStatGuard stat_guard("InitCompileCache", "compile_cache", true); + if (!common::UseHostCollective()) { + context.set_init_compile_cache(true); + } + bool compile_cache_consistent = jit_executor->GetCompileCacheConsistent(); + if (!compile_cache_consistent) { + MS_LOG(WARNING) << "Check the consistency of dependency files hash failed. Execute all the compilation actions."; + return std::make_pair(nullptr, compile_cache_manager); + } + } + auto manager = MakeManager({}, true, true); + FuncGraphPtr func_graph = compile_cache_manager->GetCachedFuncGraph(manager, weights, "", backward_graph); + return std::make_pair(func_graph, compile_cache_manager); +} + +VectorRef ExecuteForward(const pynative::GradParamPtr &grad_param, const FuncGraphPtr &forward_fg, + const bool need_forward_result, const bool need_reuse_forward_node, const bool cache_hit) { + // 2. Execute forward graph if needed + // Prepare argument list for graph execution + VectorRef arg_list; + std::transform(grad_param->op_grad_info->input_value.begin(), grad_param->op_grad_info->input_value.end(), + std::back_inserter(arg_list), [](const ValuePtr &value) { return value; }); + ValuePtr forward_output_value = grad_param->op_grad_info->out_value; + AbstractBasePtr origin_forward_output_abs = grad_param->op_grad_info->out_abs; + MS_EXCEPTION_IF_NULL(origin_forward_output_abs); + MS_EXCEPTION_IF_NULL(forward_fg); + if (need_forward_result) { + MS_LOG(INFO) << "Start run forward graph result"; + const auto &output = forward_fg->output(); + MS_EXCEPTION_IF_NULL(output); + const auto &output_abs = output->abstract(); + MS_EXCEPTION_IF_NULL(output_abs); + if (need_reuse_forward_node) { + // {prim::kPrimMakeTuple, origin_forward_output, {prim::kPrimMakeTuple, reuse_cnode1, reuse_cnode2, ...}} + auto tuple_output_abstract = output_abs->cast(); + if (tuple_output_abstract == nullptr || tuple_output_abstract->size() == 0) { + MS_LOG(EXCEPTION) << "Invalid output 
abstract: " << output_abs->ToString(); + } + auto node_abstracts = tuple_output_abstract->elements(); + node_abstracts[kIndex0] = origin_forward_output_abs; + output->set_abstract(std::make_shared(node_abstracts)); + } else { + output->set_abstract(origin_forward_output_abs); + } + auto forward_result = GetGraphResult(forward_fg, arg_list, cache_hit, grad_param->graph_cache_key); + py::object py_forward_result = + HandleForwardResult(forward_result, forward_fg, origin_forward_output_abs, grad_param, need_reuse_forward_node); + MS_LOG(DEBUG) << "Run forward graph get result: " << py::str(py_forward_result); + forward_output_value = PyObjToValue(py_forward_result); + grad_param->op_grad_info->out_value = forward_output_value; + } + return arg_list; +} + std::pair GetBpropGraph(const pynative::GradParamPtr &grad_param) { MS_EXCEPTION_IF_NULL(grad_param); MS_EXCEPTION_IF_NULL(grad_param->op_grad_info); @@ -409,10 +485,32 @@ std::pair GetBpropGraph(const pynative::GradParamPtr &grad_p // 1. Check cache for existing graphs const auto it = pass_grad_graph_.find(grad_param->graph_cache_key); bool cache_hit = it != pass_grad_graph_.end(); + pipeline::CompileCacheManagerPtr compile_cache_manager = nullptr; + pipeline::CompileCacheManagerPtr compile_cache_manager_forward = nullptr; + bool loaded = false; + if (CompileCacheEnable() && !cache_hit) { + auto graph_executor = pipeline::GraphExecutorPy::GetInstance(); + const auto &weights = graph_executor->get_weights_values(); + { + MsProfileStatGuard stat_guard("LoadCachedFuncGraph"); + static size_t idx = 0; + auto pair = GetCompileCacheResource(weights, "grad", idx++, true); + after_opt_fg = pair.first; + compile_cache_manager = pair.second; + } + { + MsProfileStatGuard stat_guard("LoadCachedFuncGraph"); + static size_t idx_forward = 0; + auto pair_forward = GetCompileCacheResource(weights, "grad_forward", idx_forward++, false); + forward_fg = pair_forward.first; + compile_cache_manager_forward = pair_forward.second; + } + 
loaded = after_opt_fg != nullptr && forward_fg != nullptr; + } if (cache_hit) { MS_LOG(DEBUG) << "Get ad grad graph by cache, cache key: " << grad_param->graph_cache_key; std::tie(forward_fg, after_opt_fg) = it->second; - } else { + } else if (!loaded) { // Generate backward graph and forward graph with reused cnode as output jit_adgrad_processer = std::make_shared( BasicClone(grad_param->fg), grad_param->op_grad_info->input_abs, grad_param->op_grad_info->input_value, @@ -435,40 +533,9 @@ std::pair GetBpropGraph(const pynative::GradParamPtr &grad_p pynative::CommonUtils::DumpGraphIR("opt_forward.ir", forward_fg); } - // 2. Execute forward graph if needed - // Prepare argument list for graph execution - VectorRef arg_list; - std::transform(grad_param->op_grad_info->input_value.begin(), grad_param->op_grad_info->input_value.end(), - std::back_inserter(arg_list), [](const ValuePtr &value) { return value; }); + VectorRef arg_list = ExecuteForward(grad_param, forward_fg, need_forward_result, need_reuse_forward_node, cache_hit); ValuePtr forward_output_value = grad_param->op_grad_info->out_value; AbstractBasePtr origin_forward_output_abs = grad_param->op_grad_info->out_abs; - MS_EXCEPTION_IF_NULL(origin_forward_output_abs); - MS_EXCEPTION_IF_NULL(forward_fg); - if (need_forward_result) { - MS_LOG(INFO) << "Start run forward graph result"; - const auto &output = forward_fg->output(); - MS_EXCEPTION_IF_NULL(output); - const auto &output_abs = output->abstract(); - MS_EXCEPTION_IF_NULL(output_abs); - if (need_reuse_forward_node) { - // {prim::kPrimMakeTuple, origin_forward_output, {prim::kPrimMakeTuple, reuse_cnode1, reuse_cnode2, ...}} - auto tuple_output_abstract = output_abs->cast(); - if (tuple_output_abstract == nullptr || tuple_output_abstract->size() == 0) { - MS_LOG(EXCEPTION) << "Invalid output abstract: " << output_abs->ToString(); - } - auto node_abstracts = tuple_output_abstract->elements(); - node_abstracts[kIndex0] = origin_forward_output_abs; - 
output->set_abstract(std::make_shared(node_abstracts)); - } else { - output->set_abstract(origin_forward_output_abs); - } - auto forward_result = GetGraphResult(forward_fg, arg_list, cache_hit, grad_param->graph_cache_key); - py::object py_forward_result = - HandleForwardResult(forward_result, forward_fg, origin_forward_output_abs, grad_param, need_reuse_forward_node); - MS_LOG(DEBUG) << "Run forward graph get result: " << py::str(py_forward_result); - forward_output_value = PyObjToValue(py_forward_result); - grad_param->op_grad_info->out_value = forward_output_value; - } // 3. Update grad_param info about forward output value grad_param->args = arg_list; @@ -483,8 +550,20 @@ std::pair GetBpropGraph(const pynative::GradParamPtr &grad_p // 4. Store forward_graph and bprop if (!cache_hit) { - jit_adgrad_processer->SetForwardOutputAbs(grad_param->op_grad_info->out_abs, after_opt_fg); - pynative::CommonUtils::DumpGraphIR("opt_backward.ir", after_opt_fg); + if (!CompileCacheEnable() || !loaded) { + jit_adgrad_processer->SetForwardOutputAbs(grad_param->op_grad_info->out_abs, after_opt_fg); + pynative::CommonUtils::DumpGraphIR("opt_backward.ir", after_opt_fg); + } + if (CompileCacheEnable() && !loaded) { + { + MsProfileStatGuard stat_guard("SaveCacheFuncGraph", "compile_cache", true); + compile_cache_manager->CacheFuncGraph(after_opt_fg, nullptr, false, true); + } + { + MsProfileStatGuard stat_guard("SaveCacheFuncGraph", "compile_cache", true); + compile_cache_manager_forward->CacheFuncGraph(forward_fg, nullptr, false, true); + } + } if (grad_param->is_jit_graph) { pass_grad_graph_[grad_param->graph_cache_key] = {forward_fg, after_opt_fg}; } @@ -498,6 +577,7 @@ void ClearGradCache() { check_invalid_dout_bprop_graph.clear(); origin_grad_graph_.clear(); filtered_grad_graph.clear(); + CompileCacheContext::GetInstance().Clear(); } void BpropGenerator::ReuseCustomBpropForwardOutput(const FuncGraphPtr &k_fg, const FuncGraphPtr &top_fg) { diff --git 
a/mindspore/core/load_mindir/load_model.cc b/mindspore/core/load_mindir/load_model.cc index b2c565ec43e1235af931a044309163418af58a95..69c543e6832a774109e05473ce14c6c3062b79cf 100644 --- a/mindspore/core/load_mindir/load_model.cc +++ b/mindspore/core/load_mindir/load_model.cc @@ -459,7 +459,7 @@ class MSANFModelParser { bool little_endian_ = common::IsLittleByteOrder(); std::map> tenor_data_; bool is_kernel_graph_{false}; - std::list> node_abstract_protos_; + std::list> node_abstract_protos_; }; ValuePtr MSANFModelParser::GenerateTensorValue(const mind_ir::TensorProto &tensor_proto) { @@ -1157,7 +1157,8 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi } else { auto abs = GetNodeAbstractFromAttrProtoWithType(attr_proto); if (abs == nullptr) { - MS_LOG(ERROR) << "Failed to get abstract for input node " << node->name() + node_abstract_protos_.push_back(std::pair(node, &value_proto.attr_info())); + MS_LOG(DEBUG) << "Failed to get abstract for input node " << node->name() << " from attr_proto:" << attr_proto.DebugString(); } node->set_abstract(abs); @@ -2203,16 +2204,30 @@ bool MSANFModelParser::SetValueForTopGraphParameter(const FuncGraphPtr &topGraph } void MSANFModelParser::TrytoBuildCNodeAbstract() { - std::map visited_times; + std::map visited_times; constexpr int kMaxCount = 3; while (!node_abstract_protos_.empty()) { auto &item = node_abstract_protos_.front(); - auto &count = visited_times[item.first]; + const auto &node = item.first; + auto &count = visited_times[node]; if (count++ > kMaxCount) { abstract_valid_ = false; - MS_LOG(ERROR) << "Parse CNode: " << item.first->ToString() << " abstract error: " << item.second->DebugString(); + MS_LOG(ERROR) << "Parse CNode: " << node->ToString() << " abstract error: " << item.second->DebugString(); } else { - SetCNodeAbstract(*(item.second), item.first); + if (node->isa()) { + const auto &cnode = node->cast(); + SetCNodeAbstract(*(item.second), cnode); + } else if (node->isa()) { + auto 
abs = GetNodeAbstractFromAttrProtoWithType(*(item.second)); + if (abs == nullptr) { + node_abstract_protos_.push_back(std::pair(node, item.second)); + MS_LOG(ERROR) << "Failed to get abstract for input node " << node->DebugString() + << " from attr_proto:" << item.second->DebugString(); + } + node->set_abstract(abs); + } else { + MS_LOG(ERROR) << "Trying to rebuild unexpected node: " << node->DebugString(); + } } node_abstract_protos_.pop_front(); } diff --git a/tests/st/compiler/compile_cache/run_compile_cache_custom_bprop.py b/tests/st/compiler/compile_cache/run_compile_cache_custom_bprop.py new file mode 100644 index 0000000000000000000000000000000000000000..ca54cbcb8081860fc504d26cb1847d9c4fcdbecc --- /dev/null +++ b/tests/st/compiler/compile_cache/run_compile_cache_custom_bprop.py @@ -0,0 +1,39 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +from mindspore.common import Tensor +from mindspore import context, jit, nn +from mindspore.ops.composite import GradOperation +from mindspore.common import dtype as mstype + + +class Net(nn.Cell): + @jit + def construct(self, x, y): + z = x * y + z = z * y + return z + + def bprop(self, x, y, out, dout): + x_dout = x + y + y_dout = x * y + return x_dout, y_dout, out, dout + + +context.set_context(mode=context.PYNATIVE_MODE) +grad_all = GradOperation(get_all=True) +res = grad_all(Net())(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) +print("AAA", res, "BBB") +print("AAA", res[0].asnumpy().shape, "BBB") diff --git a/tests/st/compiler/compile_cache/run_compile_cache_grad_jit.py b/tests/st/compiler/compile_cache/run_compile_cache_grad_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2bc61a2312514792e3dbb1b0dc3e7e73ccf5c8 --- /dev/null +++ b/tests/st/compiler/compile_cache/run_compile_cache_grad_jit.py @@ -0,0 +1,32 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +from mindspore.common import Tensor +from mindspore import context, jit +from mindspore.ops.composite import GradOperation + + +@jit +def func(x, y): + x = x * 3 + return 2 * x[0] + y + + +context.set_context(mode=context.PYNATIVE_MODE) +a = Tensor([1, 2, 3]) +b = Tensor([1, 1, 1]) +res = GradOperation()(func)(a, b) +print("AAA", res, "BBB") +print("AAA", res.asnumpy().shape, "BBB") diff --git a/tests/st/compiler/compile_cache/run_high_grad_jit.py b/tests/st/compiler/compile_cache/run_high_grad_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..88d4145a183d993dc47b2176384469664c826b98 --- /dev/null +++ b/tests/st/compiler/compile_cache/run_high_grad_jit.py @@ -0,0 +1,59 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +from mindspore.common import Tensor +from mindspore.common import dtype as mstype +from mindspore import nn, jit +from mindspore.ops import composite as C + +class Net(nn.Cell): + def __init__(self, num_layer): + super().__init__() + self.layers = nn.CellList() + self.dense = nn.Dense(4, 4) + for _ in range(num_layer): + self.layers.append(nn.ReLU()) + self.flatten = nn.Flatten() + + @jit + def construct(self, x): + out = x + out = self.dense(x) + for layer in self.layers: + out = layer(out) + out = self.flatten(out) + return out + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = C.GradOperation(get_all=True, sens_param=False) + self.network = network + + def construct(self, x): + gout = self.grad(self.network)(x) + return gout + +net = Net(100) +grad_net = Grad(net) +d = Tensor(shape=[None, None], dtype=mstype.float32) +grad_net.set_inputs(d) + +input_X = Tensor(np.random.randn(4, 4).astype(np.float32)) +ggrad_net = Grad(grad_net) +res = ggrad_net(input_X) +print("AAA", res, "BBB") +print("AAA", res[0].asnumpy().shape, "BBB") diff --git a/tests/st/compiler/compile_cache/run_lenet_with_jit.py b/tests/st/compiler/compile_cache/run_lenet_with_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..add8cf6312a876f08c7bb6128a0384bbd1fa2b62 --- /dev/null +++ b/tests/st/compiler/compile_cache/run_lenet_with_jit.py @@ -0,0 +1,80 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, jit +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum +from mindspore.ops import operations as P + + +class LeNet(nn.Cell): + def __init__(self): + super(LeNet, self).__init__() + self.relu = P.ReLU() + self.batch_size = 32 + + self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid', + weight_init="normal") + self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid', + weight_init="normal") + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.reshape = P.Reshape() + self.fc1 = nn.Dense(400, 120, weight_init="normal", bias_init="zeros") + self.fc2 = nn.Dense(120, 84, weight_init="normal", bias_init="zeros") + self.fc3 = nn.Dense(84, 10, weight_init="normal", bias_init="zeros") + + @jit + def construct(self, input_x): + output = self.conv1(input_x) + output = self.relu(output) + output = self.pool(output) + output = self.conv2(output) + output = self.relu(output) + output = self.pool(output) + output = self.reshape(output, (self.batch_size, -1)) + output = self.fc1(output) + output = self.relu(output) + output = self.fc2(output) + output = self.relu(output) + output = self.fc3(output) + return output + + +def train(net, data, label): + learning_rate = 0.01 + momentum = 0.9 + + optimizer = Momentum(filter(lambda x: x.requires_grad, 
+ net.get_parameters()), learning_rate, momentum) + criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell( + net_with_criterion, optimizer) # optimizer + train_network.set_train() + res = train_network(data, label) + print("AAA", res, "BBB") + print("AAA", res.asnumpy().shape, "BBB") + + +if __name__ == "__main__": + context.set_context(mode=context.PYNATIVE_MODE) + input_data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) + input_label = Tensor(np.ones([32]).astype(np.int32)) + lenet = LeNet() + train(lenet, input_data, input_label) + train(lenet, input_data, input_label) diff --git a/tests/st/compiler/compile_cache/test_compile_cache.py b/tests/st/compiler/compile_cache/test_compile_cache.py index 16419ee1c3a96d93f7a769e9af8c588ddd67f0eb..54ddc328e3de0713a915e92ac813c01e15ecb01d 100644 --- a/tests/st/compiler/compile_cache/test_compile_cache.py +++ b/tests/st/compiler/compile_cache/test_compile_cache.py @@ -32,7 +32,8 @@ match_num = re.compile(r'\d+\.?\d*', re.S) def exec_insert_command(regex, context, file_name): ret = os.system('sed -i "/{0}/{1}" {2}'.format(regex, context, file_name)) if ret != 0: - raise ValueError('exec `sed -i "/{0}/{1}" {2}` failed.'.format(regex, context, file_name)) + raise ValueError( + 'exec `sed -i "/{0}/{1}" {2}` failed.'.format(regex, context, file_name)) return ret @@ -53,7 +54,7 @@ def exec_cp_command(src, dst): def exec_model_and_check_result(cur_model_path, dataset_path, config_path, cache_path, check_context): exec_shell = f"export GLOG_v=2; export MS_COMPILER_CACHE_ENABLE=1; " \ + "export MS_COMPILER_CACHE_PATH={}; cd resnet/scripts; bash run_distribute_train.sh {} {} {}" \ - .format(cache_path, utils.rank_table_path, dataset_path, config_path) + .format(cache_path, utils.rank_table_path, dataset_path, config_path) os.system(exec_shell) cmd = "ps -ef | grep python | grep train.py | grep -v grep" ret = 
utils.process_check(100, cmd) @@ -76,7 +77,8 @@ def exec_model_and_check_result(cur_model_path, dataset_path, config_path, cache return loss -def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_file_name_second, is_debug=False): +def run_twice_with_same_network(file_name, cache_path, log_file_name_first, + log_file_name_second, is_debug=False, run_time=1): # Clear compile cache folder and log files if os.path.exists(cache_path): shutil.rmtree(cache_path) @@ -109,21 +111,24 @@ def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_ # Take out the result of the first run match_output_first = re.findall(match_output, data_first) - assert len(match_output_first) == 2 - nums_first = re.findall(match_num, match_output_first[0]) - array_first = np.array([float(x) for x in nums_first]) - shape_first = re.findall(match_num, match_output_first[1]) - array_shape_first = np.array([int(x) for x in shape_first]) + assert len(match_output_first) == 2 * run_time + array_first = [] + array_shape_first = [] + for i in range(run_time): + nums_first = re.findall(match_num, match_output_first[2 * i]) + array_first.append(np.array([float(x) for x in nums_first])) + shape_first = re.findall(match_num, match_output_first[2 * i + 1]) + array_shape_first.append(np.array([int(x) for x in shape_first])) # Second run with compile cache if not is_debug: cmd_second = f"export GLOG_v=2; export MS_COMPILER_CACHE_ENABLE=1; " \ - + "export MS_COMPILER_CACHE_PATH={}; python {} > {} 2>&1".format(cache_path, file_name, - log_file_name_second) + + "export MS_COMPILER_CACHE_PATH={}; python {} > {} 2>&1".format(cache_path, file_name, + log_file_name_second) else: cmd_second = f"export GLOG_v=0; export MS_COMPILER_CACHE_ENABLE=1; " \ - + "export MS_COMPILER_CACHE_PATH={}; python {} > {} 2>&1".format(cache_path, file_name, - log_file_name_second) + + "export MS_COMPILER_CACHE_PATH={}; python {} > {} 2>&1".format(cache_path, file_name, + 
log_file_name_second) subprocess.check_output(cmd_second, shell=True) assert os.path.exists(log_file_name_second) with open(log_file_name_second, "r") as f_second: @@ -139,14 +144,18 @@ def run_twice_with_same_network(file_name, cache_path, log_file_name_first, log_ # Take out the result of the second run match_output_second = re.findall(match_output, data_second) - assert len(match_output_second) == 2 - nums_second = re.findall(match_num, match_output_second[0]) - array_second = np.array([float(x) for x in nums_second]) - shape_second = re.findall(match_num, match_output_second[1]) - array_shape_second = np.array([int(x) for x in shape_second]) - - assert np.allclose(array_first, array_second, 0.0001, 0.0001) - assert (array_shape_first == array_shape_second).all() + assert len(match_output_second) == 2 * run_time + array_second = [] + array_shape_second = [] + for i in range(run_time): + nums_second = re.findall(match_num, match_output_second[2 * i]) + array_second.append(np.array([float(x) for x in nums_second])) + shape_second = re.findall(match_num, match_output_second[2 * i + 1]) + array_shape_second.append(np.array([int(x) for x in shape_second])) + + for i in range(run_time): + assert np.allclose(array_first[i], array_second[i], 0.0001, 0.0001) + assert (array_shape_first[i] == array_shape_second[i]).all() # Clean files os.remove(log_file_name_first) @@ -229,8 +238,10 @@ def clear_and_make_run_dir(dir_path): def check_compile_cache_files(cache_path, role): assert os.path.exists(cache_path) - assert os.path.exists(cache_path + "/rank_0/graph_cache/" + role + "compile_cache_0.mindir") - assert os.path.exists(cache_path + "/rank_0/graph_cache/" + role + "compile_dependency.hash") + assert os.path.exists( + cache_path + "/rank_0/graph_cache/" + role + "compile_cache_0.mindir") + assert os.path.exists( + cache_path + "/rank_0/graph_cache/" + role + "compile_dependency.hash") def run_network_once_with_force_use_compile_cache(file_name, cache_path, 
log_file_name_first): @@ -269,7 +280,8 @@ def test_compile_cache_load_weights(): Description: Test whether the compile cache can load the value of parameters successfully. Expectation: success. """ - run_twice_with_same_network("run_network_with_weights.py", "./weight", "weight_first.txt", "weight_second.txt") + run_twice_with_same_network( + "run_network_with_weights.py", "./weight", "weight_first.txt", "weight_second.txt") @arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='unessential') @@ -279,7 +291,8 @@ def test_compile_cache_lenet(): Description: Test whether the regular compile cache function can run successfully. Expectation: success. """ - run_twice_with_same_network("run_lenet.py", "./lenet", "lenet_first.txt", "lenet_second.txt", True) + run_twice_with_same_network( + "run_lenet.py", "./lenet", "lenet_first.txt", "lenet_second.txt", True) @arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') @@ -355,7 +368,8 @@ def test_compile_cache_run_two_cells_once(): Description: Test whether all the cells don't read the cached graph when run multiple cells once. Expectation: success. 
""" - run_two_cells_networks_once("run_lenet_two_cells.py", "./lenet_two_cells", "lenet_two_cells.txt") + run_two_cells_networks_once( + "run_lenet_two_cells.py", "./lenet_two_cells", "lenet_two_cells.txt") @arg_mark(plat_marks=['platform_gpu'], level_mark='level0', card_mark='onecard', essential_mark='essential') @@ -400,7 +414,6 @@ def test_compile_cache_with_inplace_tensor(): self.assignadd(k, ops.ones_like(k)) self.assignadd(v, ops.ones_like(v)) - kv_cache_shape = (None, 1) kv_cache_dtype = mstype.int32 dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) @@ -423,3 +436,47 @@ def test_compile_cache_with_inplace_tensor(): assert len(ms_compile_cache) == 1 assert kv_cache[0][0][0] == 3 assert kv_cache[1][0][0] == 3 + + +@arg_mark(plat_marks=['cpu_linux'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_compile_cache_grad_jit(): + """ + Feature: Compile cache. + Description: Test compile cache with grad jit. + Expectation: success. + """ + run_twice_with_same_network("run_compile_cache_grad_jit.py", "./compile_cache_grad_jit", + "compile_cache_grad_jit_first.txt", "compile_cache_grad_jit_second.txt") + + +@arg_mark(plat_marks=['cpu_linux'], level_mark='level0', card_mark='onecard', essential_mark='essential') +def test_compile_cache_grad_jit_with_custom_bprop(): + """ + Feature: Compile cache. + Description: Test compile cache with grad jit. + Expectation: success. + """ + run_twice_with_same_network("run_compile_cache_custom_bprop.py", "./compile_cache_custom_bprop", + "compile_cache_custom_bprop_first.txt", "compile_cache_custom_bprop.txt") + + +@arg_mark(plat_marks=['platform_ascend'], level_mark='level1', card_mark='onecard', essential_mark='unessential') +def test_compile_cache_lenet_with_jit(): + """ + Feature: Compile cache. + Description: Test whether the regular compile cache function with jit can run successfully. + Expectation: success. 
+    """
+    run_twice_with_same_network(
+        "run_lenet_with_jit.py", "./lenet_with_jit", "lenet_with_jit_first.txt", "lenet_with_jit_second.txt", False, 2)
+
+
+@arg_mark(plat_marks=['cpu_linux'], level_mark='level0', card_mark='onecard', essential_mark='essential')
+def test_compile_cache_high_grad_jit():
+    """
+    Feature: Compile cache.
+    Description: Test compile cache with high-order grad jit.
+    Expectation: success.
+    """
+    run_twice_with_same_network("run_high_grad_jit.py", "./compile_cache_high_grad_jit",
+                                "compile_cache_high_grad_jit_first.txt", "compile_cache_high_grad_jit_second.txt")