diff --git a/ptdbg_ascend/test/ut/compare/test_compare.py b/ptdbg_ascend/test/ut/compare/test_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0c6f8ab5a9a17148d7ac8f4b732d0009e15129 --- /dev/null +++ b/ptdbg_ascend/test/ut/compare/test_compare.py @@ -0,0 +1,46 @@ +import os +import unittest +from ptdbg_ascend.compare.acc_compare import compare_by_op, read_dump_path +from ptdbg_ascend.common.utils import CompareConst, CompareException, check_file_or_directory_path, \ + check_compare_param + + +class TestCompareFunctions(unittest.TestCase): + def test_compare_by_op_when_dump_file_not_found(self): + op_name = "Functional_conv2d_10_forward" + op_name_mapping_dict = { + 'Functional_conv2d_10_forward':[ + 'Functional_conv2d_10_forward.input.0', 'Functional_conv2d_10_forward.input.1' + ] + } + input_parma = {'npu_dump_data_dir':'npu_dump', 'bench_dump_data_dir':'gpu_dump'} + cos_sim, max_abs_err, err_msg = compare_by_op(op_name, op_name_mapping_dict, input_parma) + self.assertEqual(cos_sim, CompareConst.NAN) + self.assertEqual(max_abs_err, CompareConst.NAN) + self.assertIn('Dump file', err_msg) + self.assertIn('not found', err_msg) + + def test_read_dump_path_when_path_invalid(self): + noop_path = 'noop.path' + self.assertRaises(FileNotFoundError, read_dump_path, noop_path) + + def test_check_file_or_directory_path_when_file_is_invalid(self): + npu_pkl, npu_dump = 'noop.path', '../resources/compare/npu_dump.pkl' + self.assertRaises(CompareException, check_file_or_directory_path, npu_pkl, False) + self.assertRaises(CompareException, check_file_or_directory_path, npu_dump, True) + + def test_check_compare_param_when_param_invalid(self): + self.assertRaises(CompareException, check_compare_param, input_parma=['not a dict'], + output_path='', stack_mode=True, auto_analyze=False, fuzzy_match=False) + self.assertRaises(CompareException, check_compare_param, input_parma={}, + output_path=233, stack_mode=True, auto_analyze=False, fuzzy_match=False) + self.assertRaises(CompareException, check_compare_param, input_parma={}, + output_path='', stack_mode=3, auto_analyze=False, fuzzy_match=False) + self.assertRaises(CompareException, check_compare_param, input_parma={}, + output_path='', stack_mode=True, auto_analyze=[], fuzzy_match=False) + self.assertRaises(CompareException, check_compare_param, input_parma={}, + output_path='', stack_mode=True, auto_analyze=False, fuzzy_match='') + +if __name__ == '__main__': + unittest.main() + \ No newline at end of file