From 830981be7e8f4b7953119f17131f26eb078013a7 Mon Sep 17 00:00:00 2001 From: zyw_hw Date: Wed, 5 Nov 2025 16:13:56 +0800 Subject: [PATCH] reboot node skip load ckpt --- mindformers/trainer/base_trainer.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 0bed5b38b..ab34d9ed1 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -81,6 +81,7 @@ from .utils import set_seed, check_train_data_loader_type, \ check_eval_data_loader_type, check_optimizer_and_lr_type, check_wrapper_config from ..version_control import check_tft_valid, check_tre_valid, check_tsp_valid, check_is_reboot_node +# pylint: disable=import-outside-toplevel SUPPORT_TASKS = MindFormerBook().get_trainer_support_task_list() SUPPORT_MODEL_NAMES = MindFormerBook().get_model_name_support_list() SUPPORT_PIPELINES = MindFormerBook().get_pipeline_support_task_list() @@ -105,12 +106,12 @@ class BaseTrainer: def __init__(self, task: str = None, model_name: str = None): host_name_output = subprocess.run(['hostname'], shell=False, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, encoding='utf-8') + stderr=subprocess.PIPE, encoding='utf-8', check=True) host_ip_output = subprocess.run(['hostname', '-I'], shell=False, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, encoding='utf-8') + stderr=subprocess.PIPE, encoding='utf-8', check=True) host_name = host_name_output.stdout.strip() host_ip = host_ip_output.stdout.strip().split(' ')[0] - logger.info("host_name: %s, host_ip: %s" % (host_name, host_ip)) + logger.info(f"host_name: {host_name}, host_ip: {host_ip}") if model_name is None: model_name = "model name unspecified." @@ -903,7 +904,7 @@ class BaseTrainer: Postprocess the training dataset after construction. Mainly used to adjust dataset size for special dataloaders. """ - dataloader_config = config.train_dataset.get('data_loader', dict()) + dataloader_config = config.train_dataset.get('data_loader', {}) dataloader_type = dataloader_config.get('type') # Special handling for BlendedMegatronDatasetDataLoader @@ -926,7 +927,7 @@ class BaseTrainer: # Check dataset sink mode with dataset broadcast optimization level self._check_sink_mode_with_ds_broadcast(config) - dataloader_config = config.train_dataset.get('data_loader', dict()) + dataloader_config = config.train_dataset.get('data_loader', {}) dataloader_type = dataloader_config.get('type') # Case 1: BlendedMegatronDatasetDataLoader @@ -990,7 +991,7 @@ class BaseTrainer: """ cur_rank = get_rank() - src_strategy_files = sorted([f for f in os.listdir(config.src_strategy_path_or_dir)]) + src_strategy_files = sorted(list(os.listdir(config.src_strategy_path_or_dir))) if len(src_strategy_files) - 1 < cur_rank: raise ValueError(f" rank {cur_rank} src_strategy is not exist") src_strategy_file = os.path.join(config.src_strategy_path_or_dir, src_strategy_files[cur_rank]) @@ -1062,7 +1063,7 @@ class BaseTrainer: append_info = None if not config.ckpt_use_legacy_format: - if config.resume_training and config.load_checkpoint: + if config.resume_training and config.load_checkpoint and not check_is_reboot_node(): logger.info(".............Start load resume context from common.json..................") common_file = os.path.join(config.load_checkpoint, 'common.json') if not os.path.exists(common_file): @@ -1094,15 +1095,15 @@ class BaseTrainer: config.runner_config.initial_epoch = 0 config.runner_config.initial_step = 0 else: - if config.resume_training and config.load_checkpoint: + if config.resume_training and config.load_checkpoint and not check_is_reboot_node(): logger.info(".............Start load resume context from checkpoint..................") if check_tft_valid() and not config.remove_redundancy: logger.info("..............Start resume checkpoint path from strategy..............") resume_ckpt_path = self.resume_ckpt_path_with_strategy(config) if resume_ckpt_path is None: - raise ValueError("Try to resume from checkpoints with strategy in directory '{}' failed, " - "please specify load_checkpoint to specific checkpoint file to resume training." - .format(config.load_checkpoint)) + raise ValueError(f"Try to resume from checkpoints with strategy in directory " + f"'{config.load_checkpoint}' failed, please specify load_checkpoint to " + f"specific checkpoint file to resume training.") config.load_checkpoint = resume_ckpt_path load_resume_context_from_checkpoint(config, dataset) resume_dict = { -- Gitee