diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 566f5c6de75a3b441f0f472eb81459ff05ba3f6c..39675f492885520cd354739eb646fd7f31eea2ec 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -111,27 +111,28 @@ class PytorchDataProcessor(BaseDataProcessor): data_clone = data.detach() if not data_clone.numel() or not data_clone.data_ptr(): return tensor_stat - if torch.is_complex(data): + if torch.is_complex(data_clone): if async_dump: logger.warning("Async dump do not support complex data!") return tensor_stat - data_np = data.cpu().numpy() + data_np = data_clone.cpu().numpy() data_abs = np.abs(data_np) tensor_stat.max = np.max(data_abs).item() tensor_stat.min = np.min(data_abs).item() tensor_stat.mean = np.mean(data_abs).item() - elif data.dtype == torch.bool: - tensor_stat.max = torch.any(data) - tensor_stat.min = torch.all(data) - elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.clone() + elif data_clone.dtype == torch.bool: + tensor_stat.max = torch.any(data_clone) + tensor_stat.min = torch.all(data_clone) + elif not data_clone.shape: + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.clone() else: - if precision == Const.DUMP_PRECISION_HIGH or data.dtype == torch.float64 or not data.is_floating_point(): - data = data.float() - tensor_stat.max = torch.max(data) - tensor_stat.min = torch.min(data) - tensor_stat.mean = torch.mean(data) - tensor_stat.norm = torch.norm(data) + if (precision == Const.DUMP_PRECISION_HIGH or data_clone.dtype == torch.float64 + or not data_clone.is_floating_point()): + data_clone = data_clone.float() + tensor_stat.max = torch.max(data_clone) + tensor_stat.min = torch.min(data_clone) + tensor_stat.mean = torch.mean(data_clone) + tensor_stat.norm = torch.norm(data_clone) return tensor_stat @staticmethod