diff --git a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py index 2b17c4a96a2b33188bd55d08e5977e20173f793c..b0df6017ffd22bb5b40f25f76dead9f8d335e577 100644 --- a/debug/accuracy_tools/msprobe/core/compare/npy_compare.py +++ b/debug/accuracy_tools/msprobe/core/compare/npy_compare.py @@ -171,8 +171,10 @@ class TensorComparisonBasic(abc.ABC): def get_relative_err(n_value, b_value): """计算相对误差""" with np.errstate(divide='ignore', invalid='ignore'): + if n_value.dtype not in CompareConst.FLOAT_TYPE: + n_value = n_value.astype(float) if b_value.dtype not in CompareConst.FLOAT_TYPE: - n_value, b_value = n_value.astype(float), b_value.astype(float) + b_value = b_value.astype(float) n_value_copy = n_value.copy() b_value_copy = b_value.copy()