diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py index aa1894072154d2fb8a36f0cff8d56ad7f0ed020a..53bf260e95dad473fbb65e2177bc402613c2eae5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/graph_mode_cell_dump.py @@ -33,6 +33,12 @@ try: except ImportError: tensordump_flag = False +graph_step_flag = True +try: + from mindspore._c_expression import _dump_step +except ImportError: + graph_step_flag = False + class GraphModeCellDump: task = CoreConst.STATISTICS @@ -64,6 +70,15 @@ class GraphModeCellDump: _run_op(ops.TensorDump(), "TensorDump", (step_flag, temp_tensor)) ops.tensordump(step_flag, temp_tensor) + # 更新静态图KBK dump的step数 + if GraphModeCellDump.task == CoreConst.STATISTICS: + if not graph_step_flag: + raise Exception( + "Importing _dump_step failed, " + "please use the latest version package of MindSpore." + ) + _dump_step(1) + def check_config(self, strict): if not self.net: raise Exception("The model is empty and cell dump is not enabled.")