diff --git a/torch_npu/csrc/profiler/npu_profiler.cpp b/torch_npu/csrc/profiler/npu_profiler.cpp index d369bb20d3285df9efdb05e85db5e23f887cfc22..658bbffafa485e622dc2e0ca23b3412513b7f577 100644 --- a/torch_npu/csrc/profiler/npu_profiler.cpp +++ b/torch_npu/csrc/profiler/npu_profiler.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include @@ -29,33 +30,42 @@ void call(Command c) } } // python_tracer using torch_npu::toolkit::profiler::Utils; +using torch::autograd::profiler::ProfilerConfig; +using torch::autograd::profiler::ProfilerState; +using torch::profiler::impl::ProfilerStateBase; +using torch::profiler::impl::ActiveProfilerType; struct NpuObserverContext : public at::ObserverContext { explicit NpuObserverContext(std::unique_ptr data) : data_(std::move(data)) {} std::unique_ptr data_; }; -struct NpuProfilerThreadLocalState : public c10::MemoryReportingInfoBase { - explicit NpuProfilerThreadLocalState( - const NpuProfilerConfig &config, - std::set activities) - : config_(config), - activities_(std::move(activities)) {} - ~NpuProfilerThreadLocalState() override = default; - - static NpuProfilerThreadLocalState *getTLS() { - return static_cast( - c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE) - ); - } +struct NpuProfilerThreadLocalState : public ProfilerStateBase { + explicit NpuProfilerThreadLocalState( + const NpuProfilerConfig &config, + std::set activities) + : ProfilerStateBase(ProfilerConfig(ProfilerState::CPU)), npu_config_(config), activities_(std::move(activities)) + { + // copy value from NpuProfilerConfig to ProfilerConfig for compatibility + config_.report_input_shapes = config.record_shapes; + config_.profile_memory = config.profile_memory; + config_.with_stack = config.with_stack; + config_.with_flops = config.with_flops; + config_.with_modules = config.with_modules; + } + ~NpuProfilerThreadLocalState() override = default; - const NpuProfilerConfig &config() const { - return config_; - } + static NpuProfilerThreadLocalState *getTLS() + { + return static_cast( + c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PROFILER_STATE) + ); + } - const std::set &activities() const { - return activities_; - } + const std::set &activities() const + { + return activities_; + } std::unique_ptr newOpEvent(const at::RecordFunction &fn) { @@ -73,25 +83,30 @@ struct NpuProfilerThreadLocalState : public c10::MemoryReportingInfoBase { ); } - bool memoryProfilingEnabled() const { - return config_.profile_memory; - } + bool memoryProfilingEnabled() const + { + return config_.profile_memory; + } - bool tracePython() { - return (config_.with_stack || config_.with_modules) && activities_.count(NpuActivityType::CPU); - } + bool tracePython() + { + return (config_.with_stack || config_.with_modules) && activities_.count(NpuActivityType::CPU); + } - void setCallbackHandle(at::CallbackHandle handle) { - handle_ = handle; - } + void setCallbackHandle(at::CallbackHandle handle) + { + handle_ = handle; + } - at::CallbackHandle callbackHandle() const { - return handle_; - } + at::CallbackHandle callbackHandle() const + { + return handle_; + } - bool hasCallbackHandle() { - return handle_ > 0; - } + bool hasCallbackHandle() + { + return handle_ > 0; + } // Only CPU void reportMemoryUsage( @@ -120,10 +135,15 @@ struct NpuProfilerThreadLocalState : public c10::MemoryReportingInfoBase { } } + ActiveProfilerType profilerType() override + { + return ActiveProfilerType::NONE; + } + protected: - NpuProfilerConfig config_; - std::set activities_; - at::CallbackHandle handle_ = 0; + NpuProfilerConfig npu_config_; + std::set activities_; + at::CallbackHandle handle_ = 0; }; std::atomic& profDataReportEnable() @@ -131,30 +151,31 @@ std::atomic& profDataReportEnable() return ProfilerMgr::GetInstance()->ReportEnable(); } -void initNpuProfiler(const std::string &path, const std::set &activities) { - if (path.empty()) { - return; - } - std::string absPath = Utils::RelativeToAbsPath(path); - if (Utils::IsSoftLink(absPath)) { - ASCEND_LOGE("Path %s is soft link.", absPath.c_str()); - return; - } - if (!Utils::IsFileExist(absPath) && !Utils::CreateDir(absPath)) { - ASCEND_LOGE("Path %s not exist and create failed.", absPath.c_str()); - return; - } - if (!Utils::IsDir(absPath) || !Utils::IsFileWritable(absPath)) { - ASCEND_LOGE("%s is not a directory or is not writable.", absPath.c_str()); - return; - } - bool npu_trace = false; - if (activities.count(NpuActivityType::NPU)) { - npu_trace = true; - } - std::string realPath = Utils::RealPath(absPath); - TORCH_CHECK(!realPath.empty(), "Invalid path", path, PROF_ERROR(ErrCode::PARAM)); - ProfilerMgr::GetInstance()->Init(realPath, npu_trace); +void initNpuProfiler(const std::string &path, const std::set &activities) +{ + if (path.empty()) { + return; + } + std::string absPath = Utils::RelativeToAbsPath(path); + if (Utils::IsSoftLink(absPath)) { + ASCEND_LOGE("Path %s is soft link.", absPath.c_str()); + return; + } + if (!Utils::IsFileExist(absPath) && !Utils::CreateDir(absPath)) { + ASCEND_LOGE("Path %s not exist and create failed.", absPath.c_str()); + return; + } + if (!Utils::IsDir(absPath) || !Utils::IsFileWritable(absPath)) { + ASCEND_LOGE("%s is not a directory or is not writable.", absPath.c_str()); + return; + } + bool npu_trace = false; + if (activities.count(NpuActivityType::NPU)) { + npu_trace = true; + } + std::string realPath = Utils::RealPath(absPath); + TORCH_CHECK(!realPath.empty(), "Invalid path", path, PROF_ERROR(ErrCode::PARAM)); + ProfilerMgr::GetInstance()->Init(realPath, npu_trace); } static void parseInputShapesAndDtypes(const at::RecordFunction &fn, @@ -202,7 +223,7 @@ static void registerCallback(const std::unordered_set &scopes) const auto &config = state_ptr->config(); auto ctx_ptr = state_ptr->newOpEvent(fn); auto &data_ptr = ctx_ptr->data_; - if ((C10_UNLIKELY(config.record_shapes))) { + if ((C10_UNLIKELY(config.report_input_shapes))) { parseInputShapesAndDtypes(fn, data_ptr->input_dtypes, data_ptr->input_shapes); } if (C10_UNLIKELY(config.with_stack && fn.scope() != at::RecordScope::BACKWARD_FUNCTION)) { @@ -230,7 +251,7 @@ static void registerCallback(const std::unordered_set &scopes) } } ) - .needsInputs(registeration_state_ptr->config().record_shapes) + .needsInputs(registeration_state_ptr->config().report_input_shapes) .scopes(scopes) ); registeration_state_ptr->setCallbackHandle(handle);