diff --git a/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py b/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py index 9eb6da2816de4997f16af071fbf1766879e4cdab..54d7211fc690a85a5bea5267daf9062bc2eddc72 100644 --- a/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py +++ b/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py @@ -44,11 +44,13 @@ class TimeStatistics: def __enter__(self): if self.debug: self.time = datetime.now() + logger_debug(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \ + f'Id[{self.index}]') def __exit__(self, exc_type, exc_val, exc_tb): if self.debug: cost_time = datetime.now() - self.time - time_cost = f'Time[{self.tag}]: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \ + time_cost = f'Time[{self.tag}]-EXIT: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \ f'Id[{self.index}], time[{cost_time}]' hot_time_cost = "Hotspot " + time_cost diff --git a/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py b/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py index 149eca3c2d32b92a96f1448c76136451cf4c614a..aa6874a044831bf809c011ac6805bcfaf45a6962 100644 --- a/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py +++ b/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py @@ -4,6 +4,14 @@ import logging import psutil import torch import numpy as np + +try: + import torch_npu +except ImportError: + pta_cpu_device = None +else: + pta_cpu_device = torch.device("cpu") + from ..common.utils import print_error_log, CompareConst cpu_device = torch._C.device("cpu") @@ -72,7 +80,7 @@ def data_to_cpu(data, deep, data_cpu): global cpu_device list_cpu = [] if isinstance(data, torch.Tensor): - if data.device == cpu_device: + if data.device == cpu_device or data.device == pta_cpu_device: tensor_copy = data.clone().detach() else: tensor_copy = data.cpu()