diff --git a/mindformers/tools/resume_ckpt.py b/mindformers/tools/resume_ckpt.py index 7a577178a6a1775abe52133a99cb3041ffb306f4..d3dd9c603a9b2a76caa31224e531d2df88eca92b 100644 --- a/mindformers/tools/resume_ckpt.py +++ b/mindformers/tools/resume_ckpt.py @@ -22,7 +22,6 @@ from mindformers.tools.utils import ( check_in_modelarts, check_ckpt_file_name, get_real_rank, - get_real_group_size, get_epoch_and_step_from_ckpt_name, get_times_epoch_and_step_from_ckpt_name, get_rank_id_from_ckpt_name, @@ -38,6 +37,8 @@ from mindformers.trainer.utils import get_last_checkpoint, is_hyper_param_existe if check_in_modelarts(): import moxing as mox +else: + mox = None NO_META = "FOUND NO META.JSON" @@ -97,7 +98,7 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt', if last_epoch is None or last_step is None or last_ckpt_file is None: logger.info("No meta.json available and will use the checkpoints " "from the last timestamp for resume training.") - check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format) + check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format) create_file(latest_checkpointed_iteration_txt, NO_META) resume_ckpt = True else: @@ -135,7 +136,7 @@ def checkpoint_health_monitor(health_ckpts_record_dir, resume_ckpt_list): health_ckpts = [] health_monitoring_json = os.path.join(health_ckpts_record_dir, 'health_ckpts.json') if os.path.exists(health_monitoring_json): - with open(health_monitoring_json, "r") as json_file: + with open(health_monitoring_json, "r", encoding="utf-8") as json_file: json_data = json.load(json_file) if not json_data: logger.warning(f"Get nothing from {json_file}.") @@ -165,7 +166,7 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id): if not check_in_modelarts(): if not os.path.exists(latest_checkpointed_iteration_txt): raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}") - with open(latest_checkpointed_iteration_txt, 'r') as f: + with 
open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f: resume_info = [line.strip() for line in f.readlines()] else: if not mox.file.exists(latest_checkpointed_iteration_txt): @@ -199,7 +200,7 @@ def get_info_from_meta(checkpoint_dir, rank_dir_num, ckpt_format='ckpt'): if not os.path.exists(meta_json): logger.warning("%s is not found.", meta_json) continue - with open(meta_json, "r") as json_file: + with open(meta_json, "r", encoding='utf-8') as json_file: try: meta_data = json.load(json_file) if not meta_data: @@ -231,7 +232,7 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num, epoch and step are consistent, and the path exists. """ # get all valid ckpts where the epoch and step values are not greater than those of last_ckpt_file. - ckpt_prefix = last_ckpt_file[:last_ckpt_file.rfind("-")] + ckpt_prefix = last_ckpt_file[:last_ckpt_file.rfind("-") + 1] last_epoch, last_step = get_epoch_and_step_from_ckpt_name(last_ckpt_file, ckpt_format) original_rank = get_rank_id_from_ckpt_name(last_ckpt_file) valid_ckpts = {} @@ -254,9 +255,9 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num, # get ckpts suitable for resuming, where their rank numbers are intact, # epoch and step are consistent, and the path exists. 
resume_ckpt_list = [] - for key in valid_ckpts: - if check_checkpoints_by_rank(valid_ckpts[key], rank_dir_num): - ckpt_file = replace_rank_id_in_ckpt_name(valid_ckpts[key][0], rank_id) + for key, ckpts in valid_ckpts.items(): + if check_checkpoints_by_rank(ckpts, rank_dir_num): + ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id) resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file) if not os.path.exists(resume_ckpt): raise FileNotFoundError(f"{resume_ckpt} is not found!") @@ -311,14 +312,14 @@ def check_meta_info(epoch, step, ckpt_file, meta_json, ckpt_format='ckpt'): return True -def check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format='ckpt'): +def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='ckpt'): """ Verify that the prefix, epoch and step of the checkpoints from the last timestamp are equal across all rank folders in the checkpoint_dir directory. """ compared_checkpoint_name = None compared_original_checkpoint_name = None - for rank_id_tmp in range(get_real_group_size()): + for rank_id_tmp in range(rank_dir_num): checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}") last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format) if not last_checkpoint: @@ -334,7 +335,7 @@ def check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format='ckpt'): return find_diff_ckpt = False - for rank_id_tmp in range(get_real_group_size()): + for rank_id_tmp in range(rank_dir_num): checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}") last_checkpoint = get_last_checkpoint(checkpoint_rank_dir) diff --git a/mindformers/tools/utils.py b/mindformers/tools/utils.py index f250d48abfff645162d33c7953a01dccb101aa1d..520f51585208b8a78c15696d15615db84e0818cd 100644 --- a/mindformers/tools/utils.py +++ b/mindformers/tools/utils.py @@ -36,6 +36,8 @@ from mindspore import Tensor, context from mindspore._checkparam import args_type_check from mindspore.communication import 
get_group_size, get_rank, comm_func, get_local_rank +# pylint: disable=import-outside-toplevel + PARALLEL_MODE = {'DATA_PARALLEL': context.ParallelMode.DATA_PARALLEL, 'SEMI_AUTO_PARALLEL': context.ParallelMode.SEMI_AUTO_PARALLEL, 'AUTO_PARALLEL': context.ParallelMode.AUTO_PARALLEL, @@ -85,7 +87,7 @@ class Validator: def check_type(arg_value, arg_type): """Check int.""" if not isinstance(arg_value, arg_type): - raise TypeError('{} should be {} type, but get {}'.format(arg_value, arg_type, type(arg_value))) + raise TypeError(f'{arg_value} should be {arg_type} type, but get {type(arg_value)}') @staticmethod def is_obs_url(url): @@ -96,11 +98,11 @@ class Validator: def check_obs_url(url): """Check obs url.""" if not isinstance(url, str): - raise TypeError('remote_save_url type should be a str, but get {}, ' - 'please check your remote_save_url config'.format(type(url))) + raise TypeError(f'remote_save_url type should be a str, but get {type(url)}, ' + 'please check your remote_save_url config') if not (url.startswith(_PROTOCOL + '://') or url.startswith(_PROTOCOL_S3 + '://')): raise TypeError('remote_save_url should be start with obs:// or s3://, ' - 'but get {}, please check your remote_save_url config'.format(url)) + f'but get {url}, please check your remote_save_url config') def check_list(var_name: str, list_var: Union[Tuple, List], num: int): @@ -116,7 +118,7 @@ def check_list(var_name: str, list_var: Union[Tuple, List], num: int): """ for value in list_var: if value >= num: - raise ValueError('The index of the {} needs to be less than the number of nodes {}.'.format(var_name, num)) + raise ValueError(f'The index of the {var_name} needs to be less than the number of nodes {num}.') def check_file(file_path, file_type=None): @@ -192,7 +194,7 @@ def get_output_subpath(sub_class, rank_id=0, append_rank=True): root_path = get_output_root_path() directory = os.path.join(root_path, sub_class) if append_rank: - directory = os.path.join(directory, 
'rank_{}'.format(rank_id)) + directory = os.path.join(directory, f'rank_{rank_id}') return format_path(directory) @@ -253,9 +255,9 @@ class Const: def __setattr__(self, key, value): if key in self.__dict__: - raise PermissionError('Can not change const {0}.'.format(key)) + raise PermissionError(f'Can not change const {key}.') if not key.isupper(): - raise ValueError('Const name {0} is not all uppercase.'.format(key)) + raise ValueError(f'Const name {key} is not all uppercase.') self.__dict__[key] = value @@ -330,7 +332,7 @@ def count_params(net): def try_sync_file(file_name): """If the file is still downloading, we need to wait before the file finished downloading""" if fcntl: - with open(file_name, 'r') as fp: + with open(file_name, 'r', encoding='utf-8') as fp: fcntl.flock(fp.fileno(), fcntl.LOCK_EX) @@ -382,16 +384,14 @@ def parse_value(value): b = int(a) except (TypeError, ValueError): return False - else: - return a == b + return a == b def isfloat(x): try: float(x) except (TypeError, ValueError): return False - else: - return True + return True def isbool(x): return x in ["True", "False"] @@ -401,8 +401,7 @@ def parse_value(value): json.loads(x) except json.decoder.JSONDecodeError: return False - else: - return True + return True if isint(value): return int(value) @@ -471,7 +470,7 @@ def get_dp_from_dataset_strategy(): first_input_stra = data_strategy[0] dp = int(first_input_stra[0]) else: - raise TypeError(f"Dataset_strategy in mindspore auto parallel context is invalid, only support (tuple, list)") + raise TypeError("Dataset_strategy in mindspore auto parallel context is invalid, only support (tuple, list)") return dp @@ -681,9 +680,32 @@ def replace_rank_id_in_ckpt_name(ckpt_file, dst_rank_id): return ckpt_name -def clear_auto_trans_output(load_checkpoint=None, src_strategy_path_or_dir=None): - """clear transformed_checkpoint and strategy""" - folder_list = ["strategy", "transformed_checkpoint"] +def check_ckpt_format(load_ckpt_format='ckpt'): + """Check 
ckpt format""" + if load_ckpt_format not in ('ckpt', 'safetensors'): + raise ValueError(f"Invalid checkpoint format '{load_ckpt_format}'. " + "Only 'ckpt' and 'safetensors' formats are supported.") + + +def clear_auto_trans_output(load_checkpoint=None, src_strategy_path_or_dir=None, load_ckpt_format='ckpt'): + """ + Clear transformed_checkpoint and strategy folders based on the specified checkpoint format. + + Args: + load_checkpoint (str, optional): Path to the checkpoint file/directory to load. + If the target folder path matches this value, a ValueError is raised. Defaults to None. + src_strategy_path_or_dir (str, optional): Path to the source strategy file/directory. + If the target folder path matches this value, a ValueError is raised. Defaults to None. + load_ckpt_format (str, optional): Format of the checkpoint file. Supports 'ckpt' and 'safetensors'. + Determines which transformed checkpoint folder to clear. Defaults to 'ckpt'. + + Raises: + ValueError: If the resolved folder path (strategy or transformed/unified checkpoint) is the same as + `load_checkpoint` or `src_strategy_path_or_dir`. This prevents overwriting critical input data. 
+ """ + check_ckpt_format(load_ckpt_format) + folder_list = ["strategy", "transformed_checkpoint"] if load_ckpt_format == 'ckpt' \ + else ["strategy", "unified_checkpoint"] for folder in folder_list: if check_in_modelarts(): folder_path = os.path.join(get_remote_save_url(), folder) @@ -765,7 +787,7 @@ def get_pipeline_rank_ids(): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" if numerator % denominator != 0: - raise ValueError("{} is not divisible by {}".format(numerator, denominator)) + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator, denominator): diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 5f1dc053b378c734d12ea32c3b896f8132bfdcff..3a35544c6cb5a659f237ed5e6dd38ed261334fe5 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -63,6 +63,8 @@ from .build_trainer import build_trainer from .training_args import TrainingArguments from .utils import config2dict, get_last_checkpoint +# pylint: disable=import-outside-toplevel + __all__ = ['Trainer'] PREFIX_CHECKPOINT_DIR = "checkpoint" @@ -217,7 +219,7 @@ class Trainer: # check_task_and_model if self.task not in SUPPORT_TASKS.keys(): raise ValueError( - "The value of task must be in {}, but get {}".format(SUPPORT_TASKS.keys(), self.task)) + f"The value of task must be in {SUPPORT_TASKS.keys()}, but get {self.task}") if isinstance(self.model, (Cell, PreTrainedModel)): logger.info("The model instance has been entered, " @@ -230,7 +232,7 @@ class Trainer: else: self.is_model_instance = False if isinstance(self.model, str): - self.model = self.model + '_{}'.format(self.pet_method) if self.pet_method else self.model + self.model = f"{self.model}_{self.pet_method}" if self.pet_method else self.model if self.model not in SUPPORT_MODEL_NAMES: raise ValueError(f"model must be in {SUPPORT_MODEL_NAMES} " f"when model's type is string, but got {self.model}.") @@ 
-295,7 +297,7 @@ class Trainer: if (self.config.auto_trans_ckpt or self.config.resume_training) and \ check_in_modelarts() and self.config.remote_save_url: set_remote_save_url(self.config.remote_save_url) - logger.info(f"Set remote_save_url: %s, the output file will be uploaded to here.", + logger.info("Set remote_save_url: %s, the output file will be uploaded to here.", self.config.remote_save_url) # build trainer @@ -1010,7 +1012,7 @@ class Trainer: logger.warning("The `use_past` is set to False and reinit the incoming model.") model_config.parallel_config = self.config.parallel_config model_config.moe_config = self.config.moe_config - self.model.__init__(model_config) + self.model.__init__(model_config) # pylint: disable=unnecessary-dunder-call self.reset_model = False @staticmethod @@ -1038,7 +1040,7 @@ class Trainer: """get last checkpoint for resuming or finetune.""" output_folder = self.config.output_dir checkpoint_dir = os.path.join( - output_folder, DEFAULT_CHECKPOINT_DIR, 'rank_{}'.format(self.rank_id)) + output_folder, DEFAULT_CHECKPOINT_DIR, f'rank_{self.rank_id}') return get_last_checkpoint(checkpoint_dir, self.config.load_ckpt_format) @staticmethod @@ -1268,10 +1270,10 @@ class Trainer: config_dict = _reset_config_for_save(config) if config_dir is None: config_dir = os.path.join( - self.configs_directory, model_name.lower() + '_new') + self.configs_directory, f"{model_name.lower()}_new") if not os.path.exists(config_dir): os.makedirs(config_dir, exist_ok=True) - run_yaml_path = os.path.join(config_dir, 'run_{}.yaml'.format(model_name.lower())) + run_yaml_path = os.path.join(config_dir, f"run_{model_name.lower()}.yaml") _save_config_to_yaml(run_yaml_path, config_dict) @@ -1337,12 +1339,13 @@ class Trainer: f'does not contain any safetensors file and load_checkpoint is empty.' 
f'It will not load any weights.') - if self.config.auto_trans_ckpt and self.config.load_ckpt_format == 'ckpt': + if self.config.auto_trans_ckpt: if not is_publicly_accessible_path(get_output_root_path()): - raise ValueError(f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True," - "the output_dir should be a shared directory that can be accessed by all nodes." - f"but {os.path.abspath(self.config.output_dir)} is not a shared directory.") - clear_auto_trans_output(self.config.load_checkpoint, self.config.src_strategy_path_or_dir) + raise ValueError(f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True, " + f"the output_dir should be a shared directory that can be accessed by all nodes. " + f"But {os.path.abspath(self.config.output_dir)} is not a shared directory.") + clear_auto_trans_output( + self.config.load_checkpoint, self.config.src_strategy_path_or_dir, self.config.load_ckpt_format) if (self.config.auto_trans_ckpt or self.config.resume_training) and not self.config.load_checkpoint: if self.config.model and self.config.model.model_config.checkpoint_name_or_path: @@ -1367,10 +1370,14 @@ class Trainer: def _error_if_checkpoint_prefix_contains_rank_info(self): """Error if checkpoint prefix contains rank info""" - for callback in self.config.callbacks: - if "type" in callback and callback["type"] == "CheckpointMonitor": - if "rank" in callback.get("prefix", "mindformers"): - raise ValueError("The prefix for saving checkpoint is not allowed to contain 'rank'.") + if hasattr(self.config, "callbacks") and self.config.callbacks: + if not isinstance(self.config.callbacks, list): + raise ValueError("Expected 'callbacks' in config to be a list, " + f"but get {type(self.config.callbacks)}.") + for callback in self.config.callbacks: + if "type" in callback and callback["type"] == "CheckpointMonitor": + if "rank" in callback.get("prefix", "mindformers"): + raise ValueError("The prefix for saving checkpoint is not 
allowed to contain 'rank'.") def _check_args_task_and_model(self): """Check args, task and model.""" diff --git a/mindformers/trainer/utils.py b/mindformers/trainer/utils.py index f07218054d26ff066a6a003e8902862aa7d2a05e..a2cca068664c9e49166035d434f2334f1e27ceb5 100644 --- a/mindformers/trainer/utils.py +++ b/mindformers/trainer/utils.py @@ -44,6 +44,8 @@ from mindformers.models.base_model import BaseModel from mindformers.models.modeling_utils import PreTrainedModel from mindformers.version_control import need_nz +# pylint: disable=import-outside-toplevel + class BaseEnum(str, Enum): """ @@ -270,7 +272,7 @@ def get_distribute_checkpoint_path(checkpoint_dir, rank_id=None, ckpt_format='ck "load_checkpoint should be a checkpoint directory containing the directory of rank_{0-*}," "The directory structure is as follows: **checkpoint_root_dir/rank_{0-*}/**.ckpt") rank_id = rank_id if rank_id is not None else get_real_rank() - distribute_checkpoint_dir = os.path.join(checkpoint_dir, "rank_{}".format(rank_id)) + distribute_checkpoint_dir = os.path.join(checkpoint_dir, f"rank_{rank_id}") distribute_checkpoint_path = get_last_checkpoint(distribute_checkpoint_dir, ckpt_format) logger.info("distribute checkpoint dir: %s", distribute_checkpoint_dir) elif os.path.isfile(checkpoint_dir): @@ -308,7 +310,9 @@ def load_resume_context_from_checkpoint(config, dataset): raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}") if os.path.isdir(config.load_checkpoint): - if config.use_graceful_exit: + # When graceful exit is enabled or auto checkpoint transformation is disabled, + # use real rank ID to locate the checkpoint. 
+ if config.use_graceful_exit or not config.auto_trans_ckpt: rank_id = get_real_rank() else: rank_id = 0 @@ -472,7 +476,7 @@ def check_checkpoint_config_valid(config): def check_rank_folders(path, rank_id): """check if the folders in path are correct""" - folder_name = "rank_{}".format(rank_id) + folder_name = f"rank_{rank_id}" if not os.path.exists(os.path.join(path, folder_name)): return False return True @@ -534,14 +538,14 @@ def load_slora_ckpt(checkpoint_dict, config, network): adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json") if not os.path.exists(adapter_path): raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}") - with open(adapter_path, 'r') as file: + with open(adapter_path, 'r', encoding='utf-8') as file: path_dict = json.load(file) adapter_list = [] config_list = [] for adapter_name in network.lora_adapter.adapter_names[1:]: if adapter_name in path_dict.keys(): adapter_model = load_checkpoint(os.path.join(path_dict[adapter_name], "adapter_model.ckpt")) - with open(os.path.join(path_dict[adapter_name], "adapter_config.json"), 'r') as file: + with open(os.path.join(path_dict[adapter_name], "adapter_config.json"), 'r', encoding='utf-8') as file: adapter_config = json.load(file) else: adapter_model = {} diff --git a/mindformers/utils/load_checkpoint_utils.py b/mindformers/utils/load_checkpoint_utils.py index 66ed527b0d26b88a5c58c72e9cef7dbb8abf89e0..c8e4f1223a0e2ad395c97cecf40301ed5b0bef56 100644 --- a/mindformers/utils/load_checkpoint_utils.py +++ b/mindformers/utils/load_checkpoint_utils.py @@ -29,16 +29,14 @@ from mindspore.common.api import _pynative_executor from mindspore.communication.comm_func import barrier from mindformers.tools.logger import logger -from mindformers.tools.utils import ( - is_main_rank, - get_real_rank, - clear_auto_trans_output -) +from mindformers.tools.utils import is_main_rank, get_real_rank from mindformers.utils import convert_hf_safetensors_multiprocess, 
check_safetensors_key, is_hf_safetensors_dir from mindformers.utils.safetensors.convert_safetensors import _convert_index_json from mindformers.version_control import check_safetensors_addition_param_support from ..version_control import check_tft_valid +# pylint: disable=import-outside-toplevel + class CkptFormat(Enum): """ @@ -172,9 +170,9 @@ def extract_suffix(file_path): match = re.match(pattern, base_name) if not match: - logger.info(f"Filename '{base_name}' does not match expected pattern. " - "Skipping suffix extraction.") - return None + logger.warning(f"Filename '{base_name}' does not match expected pattern. " + f"Use '{base_name}' as the suffix.") + return base_name # Extract matched groups task_id = match.group(1) # Will be None if no task_id in filename epoch = match.group(2) step = match.group(3) if not epoch or not step: - logger.info(f"Filename '{base_name}' is missing epoch or step information. " - "Skipping suffix extraction.") - return None + logger.warning(f"Filename '{base_name}' is missing epoch or step information. 
" + f"Use '{base_name}' as the suffix.") + return base_name # Construct the appropriate suffix based on presence of task_id if task_id: @@ -194,25 +192,26 @@ def extract_suffix(file_path): def _get_src_file_suffix(config): """get file_suffix from config.load_checkpoint.""" - if isinstance(config.resume_training, str): - file_suffix = extract_suffix(config.resume_training) - return config.load_checkpoint, file_suffix - - if os.path.isfile(config.load_checkpoint): - # only support path format: path/rank_x/prefix-{epoch}_{step}.{config.load_ckpt_format} - file_suffix = extract_suffix(config.load_checkpoint) - checkpoint_dir = '/'.join(config.load_checkpoint.split('/')[:-2]) - return checkpoint_dir, file_suffix + checkpoint_dir = None + file_suffix = None + if is_main_rank(): + if isinstance(config.resume_training, str): + checkpoint_dir = config.load_checkpoint + checkpoint_name = config.resume_training.split('.')[0] + f'.{config.load_ckpt_format}' + elif os.path.isfile(config.load_checkpoint): + # only support path format: path/rank_x/prefix-{epoch}_{step}.{config.load_ckpt_format} + checkpoint_dir = '/'.join(config.load_checkpoint.split('/')[:-2]) + checkpoint_name = os.path.basename(config.load_checkpoint) + else: + checkpoint_dir = config.load_checkpoint + rank_path = f"{checkpoint_dir}/rank_0" + last_checkpoint = get_last_checkpoint(rank_path, config.load_ckpt_format) + checkpoint_name = os.path.basename(last_checkpoint) - # config.load_checkpoint is folder - rank_id = get_real_rank() - rank_path = f"{config.load_checkpoint}/rank_{rank_id}" - if not os.path.exists(rank_path): - raise FileNotFoundError(f"{rank_path} not found.") + file_suffix = extract_suffix(checkpoint_name) + logger.info(f"checkpoint name suffix: {file_suffix}") - last_checkpoint = get_last_checkpoint(rank_path, config.load_ckpt_format) - file_suffix = extract_suffix(last_checkpoint) - return config.load_checkpoint, file_suffix + return checkpoint_dir, file_suffix def 
_get_src_file(checkpoint_dir, checkpoint_name=None, ckpt_format='ckpt'): @@ -273,19 +272,19 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval load_checkpoint_files = [] strategy_path = ms.get_auto_parallel_context('strategy_ckpt_save_file') if ckpt_file_mode == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value: - logger.info(f"......Use single checkpoint file mode......") + logger.info("......Use single checkpoint file mode......") load_checkpoint_files = [config.load_checkpoint] if ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value: - logger.info(f"......Use multi checkpoint file mode......") + logger.info("......Use multi checkpoint file mode......") load_checkpoint_files = glob( os.path.join(load_checkpoint, f"*.{config.load_ckpt_format}")) load_checkpoint_files.sort() config.remove_redundancy = False elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value: - logger.info(f"......Use multi checkpoint file with rank id mode......") + logger.info("......Use multi checkpoint file with rank id mode......") # change strategy if config.auto_trans_ckpt: - logger.info(f"......auto_trans is True, will unify all rank files and slice to dst parallel strategy......") + logger.info("......auto_trans is True, will unify all rank files and slice to dst parallel strategy......") src_strategy_path = get_merged_src_strategy_path(config) unified_safetensors_path = os.path.join(config.output_dir, 'unified_checkpoint/') load_checkpoint, file_suffix = _get_src_file_suffix(config) @@ -300,7 +299,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval os.path.join(load_checkpoint, f"*.{config.load_ckpt_format}")) load_checkpoint_files.sort() else: - logger.info(f"......auto_trans is False, will not unify or slice rank files......") + logger.info("......auto_trans is False, will not unify or slice rank files......") if check_tft_valid() and not config.remove_redundancy: sf_file_name = 
load_checkpoint logger.info(f"......tft is enabled and not enable remove_redundancy, sf_file_name={sf_file_name}......") @@ -316,7 +315,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval network = model._train_network #build model if config.use_parallel: - logger.info(f"......Start build model in parallel mode......") + logger.info("......Start build model in parallel mode......") build_model(config, model, input_data, do_eval=do_eval, do_predict=do_predict) #wait generate all rank strategy files barrier() @@ -337,7 +336,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval def process_for_stand_alone_mode(config, network, strategy_path): """process for stand alone mode""" - enable_stand_alone = (config.parallel.parallel_mode == 'STAND_ALONE') + enable_stand_alone = config.parallel.parallel_mode == 'STAND_ALONE' if config.use_parallel and enable_stand_alone: from mindformers.parallel_core.inference.utils import generate_state_dict, save_strategy_file strategy_ckpt_save_dir = os.path.dirname(strategy_path) @@ -377,9 +376,9 @@ def validate_config_with_file_mode(ckpt_file_mode, use_parallel, auto_trans_ckpt def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_parallel=False, file_suffix=None, remove_redundancy=False): """merge strategy and unified safetensors.""" - logger.info("Start unify safetensors.") if is_main_rank(): # unify checkpoints + logger.info("Start unify safetensors.") logger.info(f"unified safetensors with file_suffix:{file_suffix}, remove_redundancy: {remove_redundancy}") logger.info(f"unified safetensors with save path:{unified_path}") unify_time_start = time.time() @@ -392,7 +391,10 @@ def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_paral ) unify_time_end = time.time() logger.info("Time spent unifying safetensors: %.2fs", unify_time_end - unify_time_start) - clear_auto_trans_output() + else: + logger.info("Wait for rank_0 to unify the 
safetensors, " + "please check the log of rank_0 to get the unify progress.") + if use_parallel: barrier() logger.info("Unified safetensors finished.") @@ -437,7 +439,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy ms.load_param_into_net(optimizer, hyper_param_dict) else: logger.info("......Start load checkpoint to model......") - params_dict = dict() + params_dict = {} remove_redundancy = config.get('remove_redundancy', False) for checkpoint_file in load_checkpoint_files: logger.info(f"load checkpoint file: {checkpoint_file}") @@ -614,7 +616,7 @@ def get_merged_src_strategy_path(config): def get_merged_dst_strategy_path(config, strategy_path): """prepare for dst strategy.""" - enable_stand_alone = (config.parallel.parallel_mode == 'STAND_ALONE') + enable_stand_alone = config.parallel.parallel_mode == 'STAND_ALONE' if config.use_parallel and config.auto_trans_ckpt and not enable_stand_alone: # prepare merged strategy directory merged_strategy = os.path.join(config.output_dir, 'merged_strategy')