From e7144d388b70bd67c7c2302574a10adb93d54b6c Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Fri, 8 Aug 2025 15:47:35 +0800 Subject: [PATCH] compare get_relative_err dtype error bugfix --- debug/accuracy_tools/msprobe/core/compare/npy_compare.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py index 2b17c4a96..b0df6017f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py @@ -171,8 +171,10 @@ class TensorComparisonBasic(abc.ABC): def get_relative_err(n_value, b_value): """计算相对误差""" with np.errstate(divide='ignore', invalid='ignore'): + if n_value.dtype not in CompareConst.FLOAT_TYPE: + n_value = n_value.astype(float) if b_value.dtype not in CompareConst.FLOAT_TYPE: - n_value, b_value = n_value.astype(float), b_value.astype(float) + b_value = b_value.astype(float) n_value_copy = n_value.copy() b_value_copy = b_value.copy() -- Gitee