diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py
index d14ec82880207d01123a6ce4f185468eed4526f4..17859e8f46c368ad0eb9391b2e2226267bb38680 100644
--- a/mindformers/core/callback/callback.py
+++ b/mindformers/core/callback/callback.py
@@ -43,6 +43,7 @@ from mindspore import (
     set_auto_parallel_context
 )
 from mindspore.train.callback import SummaryCollector
+from mindspore.train.callback._checkpoint import CheckpointManager
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule
 from mindspore.train.serialization import _get_merged_param_data
 from mindspore.nn.cell import Cell
@@ -1178,6 +1179,14 @@ class CheckpointMonitor(ModelCheckpoint):
         self.last_ckpoint_file = None
         self.meta_updated = True
 
+        if self.save_network_params:
+            self._network_manager = CheckpointManager(config_ck.format)
+
+        if self.save_trainable_params:
+            self._trainable_manager = CheckpointManager(config_ck.format)
+
+        self.need_remove_extra_ckpt = False
+
     def print_savetime(self, record_step, batch_num):
         """print the time cost of saving checkpoint files."""
         epoch = int((record_step - 1) // batch_num + 1)
@@ -1243,6 +1252,7 @@ class CheckpointMonitor(ModelCheckpoint):
             # keep checkpoint files number equal max number.
             if self._config.keep_checkpoint_max and 0 < self._config.keep_checkpoint_max <= self._manager.ckpoint_num:
                 self._manager.remove_oldest_ckpoint_file()
+                self.need_remove_extra_ckpt = True
             elif self._config.keep_checkpoint_per_n_minutes and self._config.keep_checkpoint_per_n_minutes > 0:
                 # pylint: disable=E0203
                 self._cur_time_for_keep = time.time()
@@ -1443,6 +1453,15 @@ class CheckpointMonitor(ModelCheckpoint):
                                    f"_{str(step_num_in_epoch)}.{self._config.format}")
             cb_cur_file = os.path.join(self.trainable_directory, cb_cur_ckpoint_file)
             os.makedirs(self.trainable_directory, exist_ok=True)
+
+            # update checkpoint file list.
+            self._trainable_manager.update_ckpoint_filelist(
+                self.trainable_directory, f"{self._prefix}-trainable_params"
+            )
+            # keep checkpoint files number equal max number.
+            if self.need_remove_extra_ckpt:
+                self._trainable_manager.remove_oldest_ckpoint_file()
+
             self.remove_redundancy(save_obj, cb_cur_file, {}, network)
             self.save_info_list[cb_params.cur_step_num]['trainable_params']['ckpt_file_path'] = cb_cur_file
             return
@@ -1453,9 +1472,18 @@ class CheckpointMonitor(ModelCheckpoint):
                                    f"_{str(step_num_in_epoch)}.{self._config.format}")
             cb_cur_file = os.path.join(self.network_directory, cb_cur_ckpoint_file)
             os.makedirs(self.network_directory, exist_ok=True)
+
+            # update checkpoint file list.
+            self._network_manager.update_ckpoint_filelist(self.network_directory, f"{self._prefix}-network")
+            # keep checkpoint files number equal max number.
+            if self.need_remove_extra_ckpt:
+                self._network_manager.remove_oldest_ckpoint_file()
+
             self.remove_redundancy(save_obj, cb_cur_file, {}, network)
             self.save_info_list[cb_params.cur_step_num]['network']['ckpt_file_path'] = cb_cur_file
 
+        self.need_remove_extra_ckpt = False
+
     def record_last_ckpt_to_json(self, epoch, step, ckpt_file):
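
Reviewer note on the pattern: `need_remove_extra_ckpt` exists because the rotation decision is made exactly once, by the main `_manager` (which compares its `ckpoint_num` against `keep_checkpoint_max`), and the auxiliary `network`/`trainable_params` directories must evict in lockstep rather than re-deriving the decision themselves. Below is a minimal, self-contained sketch of that pattern under assumed, simplified semantics for the two methods the patch calls; `ToyCheckpointManager` and all paths and prefixes are illustrative stand-ins, not MindSpore's private `mindspore.train.callback._checkpoint.CheckpointManager` API.

# A toy re-implementation of the rotate-in-lockstep pattern, NOT MindSpore code.
import os
import tempfile


class ToyCheckpointManager:
    """Illustrative stand-in tracking checkpoint files under one directory."""

    def __init__(self, fmt="ckpt"):
        self.fmt = fmt
        self.ckpoint_filelist = []

    @property
    def ckpoint_num(self):
        return len(self.ckpoint_filelist)

    def update_ckpoint_filelist(self, directory, prefix):
        # Rebuild the tracked list from files matching "<prefix>*.<fmt>",
        # oldest first (mtime, with the name as a stable tiebreaker).
        files = [
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.startswith(prefix) and f.endswith("." + self.fmt)
        ]
        self.ckpoint_filelist = sorted(files, key=lambda p: (os.path.getmtime(p), p))

    def remove_oldest_ckpoint_file(self):
        if self.ckpoint_filelist:
            os.remove(self.ckpoint_filelist.pop(0))


def run_demo(keep_checkpoint_max=2, steps=5):
    with tempfile.TemporaryDirectory() as main_dir, \
         tempfile.TemporaryDirectory() as network_dir:
        main_mgr = ToyCheckpointManager()
        network_mgr = ToyCheckpointManager()
        need_remove_extra_ckpt = False
        for step in range(1, steps + 1):
            # Save the "full" checkpoint and rotate the main directory.
            open(os.path.join(main_dir, f"run-{step}.ckpt"), "w").close()
            main_mgr.update_ckpoint_filelist(main_dir, "run")
            if main_mgr.ckpoint_num > keep_checkpoint_max:
                main_mgr.remove_oldest_ckpoint_file()
                # Record the decision so the auxiliary directory rotates in
                # lockstep -- the role need_remove_extra_ckpt plays in the patch.
                need_remove_extra_ckpt = True
            # Save the "network-only" checkpoint into its own directory.
            open(os.path.join(network_dir, f"run-network-{step}.ckpt"), "w").close()
            network_mgr.update_ckpoint_filelist(network_dir, "run-network")
            if need_remove_extra_ckpt:
                network_mgr.remove_oldest_ckpoint_file()
            # Reset at the end of the save path, as the patch does.
            need_remove_extra_ckpt = False
        print("main dir:", sorted(os.listdir(main_dir)))
        print("network dir:", sorted(os.listdir(network_dir)))


if __name__ == "__main__":
    run_demo()

Running the sketch leaves both directories holding only the two newest files, which mirrors why the patch resets the flag only at the very end of the save path: both the trainable-params and network branches within one step must observe the same eviction decision before it is cleared.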