diff --git a/mindformers/tools/resume_ckpt.py b/mindformers/tools/resume_ckpt.py
index 7a577178a6a1775abe52133a99cb3041ffb306f4..d3dd9c603a9b2a76caa31224e531d2df88eca92b 100644
--- a/mindformers/tools/resume_ckpt.py
+++ b/mindformers/tools/resume_ckpt.py
@@ -22,7 +22,6 @@ from mindformers.tools.utils import (
     check_in_modelarts,
     check_ckpt_file_name,
     get_real_rank,
-    get_real_group_size,
     get_epoch_and_step_from_ckpt_name,
     get_times_epoch_and_step_from_ckpt_name,
     get_rank_id_from_ckpt_name,
@@ -38,6 +37,8 @@ from mindformers.trainer.utils import get_last_checkpoint, is_hyper_param_existe
 
 if check_in_modelarts():
     import moxing as mox
+else:
+    mox = None
 
 
 NO_META = "FOUND NO META.JSON"
@@ -97,7 +98,7 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt',
     if last_epoch is None or last_step is None or last_ckpt_file is None:
         logger.info("No meta.json available and will use the checkpoints "
                     "from the last timestamp for resume training.")
-        check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format)
+        check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format)
         create_file(latest_checkpointed_iteration_txt, NO_META)
         resume_ckpt = True
     else:
@@ -135,7 +136,7 @@ def checkpoint_health_monitor(health_ckpts_record_dir, resume_ckpt_list):
     health_ckpts = []
     health_monitoring_json = os.path.join(health_ckpts_record_dir, 'health_ckpts.json')
     if os.path.exists(health_monitoring_json):
-        with open(health_monitoring_json, "r") as json_file:
+        with open(health_monitoring_json, "r", encoding="utf-8") as json_file:
             json_data = json.load(json_file)
             if not json_data:
                 logger.warning(f"Get nothing from {json_file}.")
@@ -165,7 +166,7 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id):
     if not check_in_modelarts():
         if not os.path.exists(latest_checkpointed_iteration_txt):
             raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}")
-        with open(latest_checkpointed_iteration_txt, 'r') as f:
+        with open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f:
             resume_info = [line.strip() for line in f.readlines()]
     else:
         if not mox.file.exists(latest_checkpointed_iteration_txt):
@@ -199,7 +200,7 @@ def get_info_from_meta(checkpoint_dir, rank_dir_num, ckpt_format='ckpt'):
         if not os.path.exists(meta_json):
             logger.warning("%s is not found.", meta_json)
             continue
-        with open(meta_json, "r") as json_file:
+        with open(meta_json, "r", encoding='utf-8') as json_file:
             try:
                 meta_data = json.load(json_file)
                 if not meta_data:
@@ -231,7 +232,7 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num,
     epoch and step are consistent, and the path exists.
     """
     # get all valid ckpts where the epoch and step values are not greater than those of last_ckpt_file.
-    ckpt_prefix = last_ckpt_file[:last_ckpt_file.rfind("-")]
+    ckpt_prefix = last_ckpt_file[:last_ckpt_file.rfind("-") + 1]
     last_epoch, last_step = get_epoch_and_step_from_ckpt_name(last_ckpt_file, ckpt_format)
     original_rank = get_rank_id_from_ckpt_name(last_ckpt_file)
     valid_ckpts = {}
@@ -254,9 +255,9 @@
     # get ckpts suitable for resuming, where their rank numbers are intact,
     # epoch and step are consistent, and the path exists.
     resume_ckpt_list = []
-    for key in valid_ckpts:
-        if check_checkpoints_by_rank(valid_ckpts[key], rank_dir_num):
-            ckpt_file = replace_rank_id_in_ckpt_name(valid_ckpts[key][0], rank_id)
+    for key, ckpts in valid_ckpts.items():
+        if check_checkpoints_by_rank(ckpts, rank_dir_num):
+            ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id)
             resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file)
             if not os.path.exists(resume_ckpt):
                 raise FileNotFoundError(f"{resume_ckpt} is not found!")
@@ -311,14 +312,14 @@ def check_meta_info(epoch, step, ckpt_file, meta_json, ckpt_format='ckpt'):
     return True
 
 
-def check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format='ckpt'):
+def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='ckpt'):
     """
     Verify that the prefix, epoch and step of the checkpoints from the last timestamp
     are equal across all rank folders in the checkpoint_dir directory.
     """
     compared_checkpoint_name = None
     compared_original_checkpoint_name = None
-    for rank_id_tmp in range(get_real_group_size()):
+    for rank_id_tmp in range(rank_dir_num):
         checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
         last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format)
         if not last_checkpoint:
@@ -334,7 +335,7 @@ def check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format='ckpt'):
             return
 
     find_diff_ckpt = False
-    for rank_id_tmp in range(get_real_group_size()):
+    for rank_id_tmp in range(rank_dir_num):
         checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}")
         last_checkpoint = get_last_checkpoint(checkpoint_rank_dir)
diff --git a/mindformers/utils/__init__.py b/mindformers/utils/__init__.py
index 4699a85432cade4298db410255ea1976e879bddf..801fe92793252fb3718dd31c5e409edc77cdc30b 100644
--- a/mindformers/utils/__init__.py
+++ b/mindformers/utils/__init__.py
@@ -27,5 +27,4 @@ from .safetensors import (
     is_hf_safetensors_dir,
     check_safetensors_key,
 )
-from .load_checkpoint_utils import validate_qkv_concat, process_hf_checkpoint
 from .decorators import deprecated
diff --git a/mindformers/utils/load_checkpoint_utils.py b/mindformers/utils/load_checkpoint_utils.py
index bc00fdf25ed239bf3454a2dc8fe1bb0088b7379a..d23375f461eca27f26f23e5decc6d6a12c486cf9 100644
--- a/mindformers/utils/load_checkpoint_utils.py
+++ b/mindformers/utils/load_checkpoint_utils.py
@@ -38,6 +38,8 @@ from mindformers.utils import convert_hf_safetensors_multiprocess, check_safeten
 from mindformers.utils.safetensors.convert_safetensors import _convert_index_json
 from mindformers.checkpoint.utils import compile_model
 from mindformers.version_control import check_safetensors_addition_param_support
+from mindformers.models.modeling_utils import PreTrainedModel
+from mindformers.parallel_core.inference.utils import generate_state_dict, save_strategy_file
 
 from ..version_control import check_tft_valid
 
@@ -173,9 +175,9 @@
     match = re.match(pattern, base_name)
 
     if not match:
-        logger.info(f"Filename '{filename}' does not match expected pattern. "
-                    "Skipping suffix extraction.")
-        return None
+        logger.warning(f"Filename '{filename}' does not match expected pattern. "
+                       f"Use '{base_name}' as the suffix.")
+        return base_name
 
     # Extract matched groups
     task_id = match.group(1)  # Will be None if no task_id in filename
@@ -183,9 +185,9 @@
     epoch = match.group(2)
     step = match.group(3)
     if not epoch or not step:
-        logger.info(f"Filename '{filename}' is missing epoch or step information. "
-                    "Skipping suffix extraction.")
-        return None
+        logger.warning(f"Filename '{filename}' is missing epoch or step information. "
+                       f"Use '{base_name}' as the suffix.")
+        return base_name
 
     # Construct the appropriate suffix based on presence of task_id
     if task_id:
@@ -195,25 +197,26 @@
 
 def _get_src_file_suffix(config):
     """get file_suffix from config.load_checkpoint."""
-    if isinstance(config.resume_training, str):
-        file_suffix = extract_suffix(config.resume_training)
-        return config.load_checkpoint, file_suffix
-
-    if os.path.isfile(config.load_checkpoint):
-        # only support path format: path/rank_x/prefix-{epoch}_{step}.{config.load_ckpt_format}
-        file_suffix = extract_suffix(config.load_checkpoint)
-        checkpoint_dir = '/'.join(config.load_checkpoint.split('/')[:-2])
-        return checkpoint_dir, file_suffix
+    checkpoint_dir = None
+    file_suffix = None
+    if is_main_rank():
+        if isinstance(config.resume_training, str):
+            checkpoint_dir = config.load_checkpoint
+            checkpoint_name = config.resume_training.split('.')[0] + f'.{config.load_ckpt_format}'
+        elif os.path.isfile(config.load_checkpoint):
+            # only support path format: path/rank_x/prefix-{epoch}_{step}.{config.load_ckpt_format}
+            checkpoint_dir = '/'.join(config.load_checkpoint.split('/')[:-2])
+            checkpoint_name = os.path.basename(config.load_checkpoint)
+        else:
+            checkpoint_dir = config.load_checkpoint
+            rank_path = f"{checkpoint_dir}/rank_0"
+            last_checkpoint = get_last_checkpoint(rank_path, config.load_ckpt_format)
+            checkpoint_name = os.path.basename(last_checkpoint)
 
-    # config.load_checkpoint is folder
-    rank_id = get_real_rank()
-    rank_path = f"{config.load_checkpoint}/rank_{rank_id}"
-    if not os.path.exists(rank_path):
-        raise FileNotFoundError(f"{rank_path} not found.")
+        file_suffix = extract_suffix(checkpoint_name)
+        logger.info(f"checkpoint name suffix: {file_suffix}")
 
-    last_checkpoint = get_last_checkpoint(rank_path, config.load_ckpt_format)
-    file_suffix = extract_suffix(last_checkpoint)
-    return config.load_checkpoint, file_suffix
+    return checkpoint_dir, file_suffix
 
 
 def _get_src_file(checkpoint_dir, checkpoint_name=None, ckpt_format='ckpt'):
@@ -274,19 +277,19 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
     load_checkpoint_files = []
     strategy_path = ms.get_auto_parallel_context('strategy_ckpt_save_file')
     if ckpt_file_mode == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value:
-        logger.info(f"......Use single checkpoint file mode......")
+        logger.info("......Use single checkpoint file mode......")
         load_checkpoint_files = [config.load_checkpoint]
     if ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value:
-        logger.info(f"......Use multi checkpoint file mode......")
+        logger.info("......Use multi checkpoint file mode......")
         load_checkpoint_files = glob(
             os.path.join(load_checkpoint, f"*.{config.load_ckpt_format}"))
         load_checkpoint_files.sort()
         config.remove_redundancy = False
     elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value:
-        logger.info(f"......Use multi checkpoint file with rank id mode......")
+        logger.info("......Use multi checkpoint file with rank id mode......")
         # change strategy
         if config.auto_trans_ckpt:
-            logger.info(f"......auto_trans is True, will unify all rank files and slice to dst parallel strategy......")
+            logger.info("......auto_trans is True, will unify all rank files and slice to dst parallel strategy......")
             src_strategy_path = get_merged_src_strategy_path(config)
             unified_safetensors_path = os.path.join(config.output_dir, 'unified_checkpoint/')
             load_checkpoint, file_suffix = _get_src_file_suffix(config)
@@ -301,7 +304,7 @@
                 os.path.join(load_checkpoint, f"*.{config.load_ckpt_format}"))
             load_checkpoint_files.sort()
         else:
-            logger.info(f"......auto_trans is False, will not unify or slice rank files......")
+            logger.info("......auto_trans is False, will not unify or slice rank files......")
             if check_tft_valid() and not config.remove_redundancy:
                 sf_file_name = load_checkpoint
                 logger.info(f"......tft is enabled and not enable remove_redundancy, sf_file_name={sf_file_name}......")
@@ -345,9 +348,8 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval
 
 def process_for_stand_alone_mode(config, network, strategy_path):
     """process for stand alone mode"""
-    enable_stand_alone = (config.parallel.parallel_mode == 'STAND_ALONE')
+    enable_stand_alone = config.parallel.parallel_mode == 'STAND_ALONE'
     if config.use_parallel and enable_stand_alone:
-        from mindformers.parallel_core.inference.utils import generate_state_dict, save_strategy_file
         strategy_ckpt_save_dir = os.path.dirname(strategy_path)
         if is_main_rank():
             if os.path.exists(strategy_ckpt_save_dir):
@@ -385,9 +387,9 @@ def validate_config_with_file_mode(ckpt_file_mode, use_parallel, auto_trans_ckpt
 def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_parallel=False,
                       file_suffix=None, remove_redundancy=False):
     """merge strategy and unified safetensors."""
-    logger.info("Start unify safetensors.")
     if is_main_rank():
         # unify checkpoints
+        logger.info("Start unify safetensors.")
         logger.info(f"unified safetensors with file_suffix:{file_suffix}, remove_redundancy: {remove_redundancy}")
         logger.info(f"unified safetensors with save path:{unified_path}")
         unify_time_start = time.time()
@@ -401,6 +403,10 @@ def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_paral
         unify_time_end = time.time()
         logger.info("Time spent unifying safetensors: %.2fs", unify_time_end - unify_time_start)
         clear_auto_trans_output()
+    else:
+        logger.info("Wait for rank_0 to unify the safetensors, "
+                    "please check the log of rank_0 to get the unify progress.")
+
     if use_parallel:
         barrier()
     logger.info("Unified safetensors finished.")
@@ -445,7 +451,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy
         ms.load_param_into_net(optimizer, hyper_param_dict)
     else:
         logger.info("......Start load checkpoint to model......")
-        params_dict = dict()
+        params_dict = {}
         remove_redundancy = config.get('remove_redundancy', False)
         for checkpoint_file in load_checkpoint_files:
             logger.info(f"load checkpoint file: {checkpoint_file}")
@@ -540,7 +546,6 @@ def validate_qkv_concat(model_cls_or_instance, qkv_concat_config, load_checkpoin
         Currently only safetensors format is supported.
     """
     # check the type of model_cls_or_instance
-    from mindformers.models.modeling_utils import PreTrainedModel
     if not (
             isinstance(model_cls_or_instance, PreTrainedModel)
             or (isinstance(model_cls_or_instance, type) and issubclass(model_cls_or_instance, PreTrainedModel))
     ):
@@ -598,7 +603,7 @@ def get_merged_src_strategy_path(config):
 
 def get_merged_dst_strategy_path(config, strategy_path):
     """prepare for dst strategy."""
-    enable_stand_alone = (config.parallel.parallel_mode == 'STAND_ALONE')
+    enable_stand_alone = config.parallel.parallel_mode == 'STAND_ALONE'
     if config.use_parallel and config.auto_trans_ckpt and not enable_stand_alone:
         # prepare merged strategy directory
         merged_strategy = os.path.join(config.output_dir, 'merged_strategy')
diff --git a/tests/st/test_safetensors/test_qkv_check.py b/tests/st/test_safetensors/test_qkv_check.py
index 503dc83021b1974bc5fd6ea1379217e846de4198..10326f1ffda5de04b9a8be6ee20c53323bb1bb82 100644
--- a/tests/st/test_safetensors/test_qkv_check.py
+++ b/tests/st/test_safetensors/test_qkv_check.py
@@ -24,7 +24,7 @@ from safetensors.numpy import save_file
 
 from mindformers import LlamaForCausalLM, ChatGLM2Model
 from mindformers.tools.register import MindFormerConfig
-from mindformers.utils import validate_qkv_concat
+from mindformers.utils.load_checkpoint_utils import validate_qkv_concat
 
 
 class TestValidateQKVConcat:
@@ -151,7 +151,7 @@ class TestValidateQKVConcat:
         assert "does not support qkv concat check" in log_content
 
     def _get_log_content(self):
-        with open(self.log_file_path, 'r') as log_file:
+        with open(self.log_file_path, 'r', encoding='utf-8') as log_file:
             log_content = log_file.read()
         return log_content
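
Note on the ckpt_prefix change in get_resume_ckpt_list (resume_ckpt.py): keeping the trailing "-" in the prefix
stops startswith-style matching from picking up checkpoints whose names merely share a shorter prefix. A minimal
standalone sketch of the difference, assuming the prefix is used for startswith filtering as the surrounding code
suggests; the checkpoint names below are hypothetical, not taken from the repository:

    # Hypothetical names following the prefix-{epoch}_{step}.ckpt convention.
    last_ckpt_file = "llama_7b_rank_0-5_2000.ckpt"

    old_prefix = last_ckpt_file[:last_ckpt_file.rfind("-")]      # 'llama_7b_rank_0'
    new_prefix = last_ckpt_file[:last_ckpt_file.rfind("-") + 1]  # 'llama_7b_rank_0-'

    candidates = [
        "llama_7b_rank_0-4_1000.ckpt",      # same run, earlier step: should match
        "llama_7b_rank_0_ema-4_1000.ckpt",  # different prefix: should not match
    ]

    print([c for c in candidates if c.startswith(old_prefix)])  # both match (too loose)
    print([c for c in candidates if c.startswith(new_prefix)])  # only the first matches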
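
Note on the extract_suffix hunks (load_checkpoint_utils.py): the failure mode changes from "log and return None"
to "warn and fall back to the file's base name", so callers always receive a usable suffix. A standalone sketch of
that contract; the regex and the stem handling here are simplified stand-ins, since the real pattern and base_name
computation live outside the hunks shown above:

    import os
    import re

    # Simplified stand-in pattern: prefix-{epoch}_{step}.safetensors
    _PATTERN = r".+-(\d+)_(\d+)\.safetensors$"

    def extract_suffix_sketch(file_path):
        base_name = os.path.basename(file_path)
        stem = base_name.rsplit(".", 1)[0]
        match = re.match(_PATTERN, base_name)
        if not match:
            # New behavior: warn and use the base name instead of returning None.
            print(f"warning: '{base_name}' does not match; using '{stem}' as the suffix")
            return stem
        epoch, step = match.group(1), match.group(2)
        return f"{epoch}_{step}"

    assert extract_suffix_sketch("rank_0/llama-3_2000.safetensors") == "3_2000"
    assert extract_suffix_sketch("rank_0/merged.safetensors") == "merged"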
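
Note on the unify_safetensors and _get_src_file_suffix hunks: the one-off work (suffix resolution, checkpoint
merging) now runs only on the main rank, while the other ranks log a hint and block on the collective barrier.
A generic sketch of that rank-0-does-the-work pattern, with hypothetical rank helpers rather than the MindFormers
API:

    # Rank 0 performs the one-off merge; all ranks synchronize afterwards,
    # so no rank reads the unified files before they exist.
    def unify_on_main_rank(rank_id, use_parallel, do_unify, barrier):
        if rank_id == 0:
            do_unify()  # heavy one-time filesystem work, done exactly once
        else:
            print("Waiting for rank_0 to unify the safetensors; "
                  "check the rank_0 log for progress.")
        if use_parallel:
            barrier()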