diff --git a/mindformers/checkpoint/checkpoint.py b/mindformers/checkpoint/checkpoint.py index 20ec0d221ee17a6302ffae4ea3977e26fe96b057..1f49ea6a1ad71692963308505366f69c48514d3f 100644 --- a/mindformers/checkpoint/checkpoint.py +++ b/mindformers/checkpoint/checkpoint.py @@ -31,17 +31,18 @@ from mindspore.common import dtype as mstype from mindspore.nn import Cell from mindspore.nn.optim.optimizer import Optimizer from mindspore.communication.management import get_rank, get_group_size -import mindspore.communication.comm_func as comm_func +from mindspore.communication import comm_func from mindspore import save_checkpoint as ms_save_checkpoint +from mindspore.parallel.strategy import get_current_strategy_metadata from mindformers.tools.logger import logger from mindformers.checkpoint.reshard import ReshardHandler +from mindformers.utils.file_utils import is_publicly_accessible_path +from mindformers.utils.parallel_utils import barrier_world from mindformers.tools.utils import ( - barrier_world, get_output_subpath, get_real_rank, set_safe_mode_for_file_or_dir, - is_publicly_accessible_path, ) from mindformers.checkpoint.utils import ( get_checkpoint_iter_dir, @@ -319,7 +320,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None, """Save checkpoint finalize function.""" tracker_filename = get_checkpoint_tracker_filename(checkpoints_root_path) logger.info(f"save checkpoint tracker file to {tracker_filename}") - with open(tracker_filename, "w") as f: + with open(tracker_filename, "w", encoding='utf-8') as f: f.write(str(iteration)) set_safe_mode_for_file_or_dir(tracker_filename) if use_async_save: @@ -861,8 +862,7 @@ def load_checkpoint( ) # Get current strategy metadata from network and optimizer - logger.info(f".........Get Current Strategy Metadata.........") - from mindspore.parallel.strategy import get_current_strategy_metadata + logger.info(".........Get Current Strategy Metadata.........") cur_rank_strategy_layout = get_current_strategy_metadata(network=network)[0] cur_rank_sharded_tensors: List[ShardedTensor] = [] diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index b30a9d2fdbb543fe2ec1ec95982b10ab09e3ae50..e3bb0cbbbd925a905ffa349593cf255cce5d0397 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -62,6 +62,7 @@ from mindformers.core.context.build_context import is_legacy_model from mindformers.tools import get_output_root_path from mindformers.tools.logger import logger from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.utils.parallel_utils import barrier_world from mindformers.tools.utils import ( get_output_subpath, get_real_rank, @@ -69,7 +70,6 @@ from mindformers.tools.utils import ( get_real_local_rank, get_pipeline_rank_ids, is_last_pipeline_stage, - barrier_world, get_ascend_log_path, set_safe_mode_for_file_or_dir ) diff --git a/mindformers/core/context/parallel.py b/mindformers/core/context/parallel.py index fb92fbb92a330b0501cf1fc10f4ec32281771446..9517d6e819086bf7dc3a66a39219d59f6c533bc9 100644 --- a/mindformers/core/context/parallel.py +++ b/mindformers/core/context/parallel.py @@ -24,7 +24,7 @@ from mindformers.modules.transformer.transformer import ( TransformerOpParallelConfig, ) from mindformers.tools.logger import logger -from mindformers.tools.utils import set_strategy_save_path +from mindformers.utils.file_utils import set_strategy_save_path from mindformers.parallel_core.inference.parallel_state import initialize_model_parallel diff --git a/mindformers/dataset/dataloader/blended_megatron_dataloader.py b/mindformers/dataset/dataloader/blended_megatron_dataloader.py index f4cf71a08295cad215f48ff6dd12184f2e813e39..f8bf6d585f093cd0527b19c484ca5f77dac1f47e 100644 --- a/mindformers/dataset/dataloader/blended_megatron_dataloader.py +++ b/mindformers/dataset/dataloader/blended_megatron_dataloader.py @@ -28,12 +28,12 @@ from mindformers.models.build_tokenizer import build_tokenizer from mindformers.tools.logger import logger from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister from mindformers.version_control import skip_barrier_controller +from mindformers.utils.file_utils import is_publicly_accessible_path from mindformers.tools.utils import ( get_dp_from_dataset_strategy, get_real_group_size, get_real_rank, - get_real_local_rank, - is_publicly_accessible_path + get_real_local_rank ) from .utils import is_dataset_built_on_rank diff --git a/mindformers/dataset/dataloader/multi_source_dataloader.py b/mindformers/dataset/dataloader/multi_source_dataloader.py index 4e62762385465c654e96c8713964dea556540566..a5abf5baa1a0271ab720f7ef30090b2522a7be00 100644 --- a/mindformers/dataset/dataloader/multi_source_dataloader.py +++ b/mindformers/dataset/dataloader/multi_source_dataloader.py @@ -23,10 +23,12 @@ from tqdm import tqdm from mindspore.dataset import Dataset, GeneratorDataset, Shuffle +from mindformers.utils.file_utils import is_publicly_accessible_path + from .build_dataloader import build_dataset_loader from ...tools.logger import logger from ...tools.register import MindFormerRegister, MindFormerModuleType -from ...tools.utils import get_real_rank, is_publicly_accessible_path, get_device_num_per_node +from ...tools.utils import get_real_rank, get_device_num_per_node @MindFormerRegister.register(MindFormerModuleType.DATASET_LOADER) @@ -49,21 +51,7 @@ class MultiSourceDataLoader: shard_id = kwargs.pop("shard_id", None) if dataset_ratios is not None: - if any([ratios < 0 for ratios in dataset_ratios]): - raise ValueError( - f"the dataset_ratios should be a list of positive value, but got {dataset_ratios}") - - if abs(sum(dataset_ratios) - 1) > 1e-5: - raise ValueError("the sum of ratios is not equals to 1") - - if samples_count is None or samples_count <= 0: - raise ValueError(f"the samples_count should be a positive int when dataset_ratios is not None, " - f"but got {samples_count}") - - if not isinstance(dataset_ratios, list) or len(dataset_ratios) != len(sub_data_loader): - raise ValueError( - "the dataset_ratios should be a list and the length should equal to sub_data_loader") - + validate_dataset_ratios(dataset_ratios, sub_data_loader, samples_count) nums_per_dataset = [int(ratio * samples_count) for ratio in dataset_ratios] need_sample = True else: @@ -71,35 +59,13 @@ class MultiSourceDataLoader: nums_per_dataset = [] need_sample = False else: - if not isinstance(nums_per_dataset, list) or len(nums_per_dataset) != len(sub_data_loader): - raise ValueError( - "the nums_per_dataset should be a list and the length should equal to sub_data_loader") - - if any([num < 0 for num in nums_per_dataset]): - raise ValueError( - f"the nums_per_dataset should be a list of positive value, but got {nums_per_dataset}") - + validate_nums_per_dataset(nums_per_dataset, sub_data_loader) need_sample = True if sub_data_loader_args is None: - sub_data_loader_args = dict() - - if not isinstance(shuffle, bool) and shuffle.lower() not in ["global", "files", "infile"]: - raise ValueError( - f"shuffle should be a bool or a str and the value must be one of ['global', 'files', 'infile']") + sub_data_loader_args = {} - if isinstance(shuffle, bool): - shuffle_dataset = shuffle - shuffle_file = shuffle - elif shuffle == Shuffle.INFILE: - shuffle_dataset = False - shuffle_file = True - elif shuffle == Shuffle.FILES: - shuffle_dataset = True - shuffle_file = False - else: - shuffle_dataset = True - shuffle_file = True + shuffle_dataset, shuffle_file = get_shuffle_flags(shuffle) sub_data_loader_args["shuffle"] = shuffle_file dataset_loaders = [] @@ -175,6 +141,57 @@ class MultiSourceDataLoader: return dataset +def validate_dataset_ratios(dataset_ratios: list[float], sub_data_loader: list, samples_count: Optional[int]) -> None: + """Validates the validity of dataset ratio parameters for data distribution.""" + if any(ratios < 0 for ratios in dataset_ratios): + raise ValueError( + f"the dataset_ratios should be a list of positive value, but got {dataset_ratios}") + + if abs(sum(dataset_ratios) - 1) > 1e-5: + raise ValueError("the sum of ratios is not equals to 1") + + if samples_count is None or samples_count <= 0: + raise ValueError(f"the samples_count should be a positive int when dataset_ratios is not None, " + f"but got {samples_count}") + + if not isinstance(dataset_ratios, list) or len(dataset_ratios) != len(sub_data_loader): + raise ValueError( + "the dataset_ratios should be a list and the length should equal to sub_data_loader") + + +def validate_nums_per_dataset(nums_per_dataset: list[int], sub_data_loader: list) -> None: + """Validates the validity of per-dataset sample count parameters.""" + if not isinstance(nums_per_dataset, list) or len(nums_per_dataset) != len(sub_data_loader): + raise ValueError( + "the nums_per_dataset should be a list and the length should equal to sub_data_loader") + + if any(num < 0 for num in nums_per_dataset): + raise ValueError( + f"the nums_per_dataset should be a list of positive value, but got {nums_per_dataset}") + + +def get_shuffle_flags(shuffle): + """Resolves shuffle configuration into dataset-level and file-level shuffle flags.""" + if not isinstance(shuffle, bool) and shuffle.lower() not in ["global", "files", "infile"]: + raise ValueError( + "shuffle should be a bool or a str and the value must be one of ['global', 'files', 'infile']") + + if isinstance(shuffle, bool): + shuffle_dataset = shuffle + shuffle_file = shuffle + elif shuffle == Shuffle.INFILE: + shuffle_dataset = False + shuffle_file = True + elif shuffle == Shuffle.FILES: + shuffle_dataset = True + shuffle_file = False + else: + shuffle_dataset = True + shuffle_file = True + + return shuffle_dataset, shuffle_file + + def prepare_generator_sub_dataloader_args(class_name, full_args): cls_obj = MindFormerRegister.get_cls(module_type="dataset_loader", class_name=class_name) @@ -350,15 +367,15 @@ class MultiSourceRandomAccessDataset: self.data_sample_index = load_dict["data_sample_index"] if save_indices_npz_path is not None: if is_publicly_accessible_path(save_indices_npz_path): - logger.info(f".......... npz file is saved in shared path ..........") + logger.info(".......... npz file is saved in shared path ..........") if self.rank_id == 0: logger.info(f".......... save indices to npz file: {save_indices_npz_path} by rank_{self.rank_id} ." f".........") np.savez_compressed(save_indices_npz_path, dataloader_index=self.dataloader_index, data_sample_index=self.data_sample_index) else: - logger.warning(f"If the npz file is being saved to a shared path, please add this path to the " - f"environment variable SHARED_PATHS, otherwise, please ignore this warning.") + logger.warning("If the npz file is being saved to a shared path, please add this path to the " + "environment variable SHARED_PATHS, otherwise, please ignore this warning.") if self.rank_id % get_device_num_per_node() == 0: logger.info(f".......... save indices to npz file: {save_indices_npz_path} by rank_{self.rank_id} ." f".........") diff --git a/mindformers/models/modeling_utils.py b/mindformers/models/modeling_utils.py index 36b9790d7e1c4924148395a656091489a76c7e08..14d404f994cd4a5b552adbb6d9d0435366e1f676 100644 --- a/mindformers/models/modeling_utils.py +++ b/mindformers/models/modeling_utils.py @@ -28,6 +28,8 @@ import mindspore as ms from mindspore import nn from mindspore import load_checkpoint, load_param_into_net from mindspore import context, Model + +import mindformers.models.auto as auto_module from mindformers.tools.check_rules import check_yaml_depth_before_loading from mindformers.tools.hub import ( PushToHubMixin, @@ -46,11 +48,10 @@ from mindformers.tools.ckpt_transform import TransformCkpt, make_soft_link from mindformers.tools.utils import ( get_real_rank, get_output_root_path, - clear_auto_trans_output, - remake_folder, - barrier_world, FILE_PERMISSION ) +from mindformers.utils.file_utils import clear_auto_trans_output, remake_folder +from mindformers.utils.parallel_utils import barrier_world from mindformers.models.utils import DEFAULT_CHECKPOINT_SAVE_FOLDER from mindformers.parallel_core.utils.model_mixin import ModelMixin from ..mindformer_book import MindFormerBook, print_path_or_list @@ -494,7 +495,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): if ( filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) - and filename not in shards.keys() + and filename not in shards and is_main_process and reg.fullmatch(filename_no_suffix) is not None ): @@ -570,7 +571,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): meraged_dict = {} if os.path.exists(config_path): - with open(config_path, 'r') as file_reader: + with open(config_path, 'r', encoding='utf-8') as file_reader: check_yaml_depth_before_loading(file_reader) file_reader.seek(0) meraged_dict = yaml.safe_load(file_reader.read()) @@ -669,7 +670,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): config_args = MindFormerConfig(yaml_file) kwargs["checkpoint_name_or_path"] = kwargs.get("checkpoint_name_or_path") \ - if "checkpoint_name_or_path" in kwargs.keys() else ckpt_file + if "checkpoint_name_or_path" in kwargs else ckpt_file config_args.model.model_config.update(**kwargs) logger.info("model config: %s and checkpoint_name_or_path: %s are used for " "model building.", yaml_file, config_args.model.model_config.checkpoint_name_or_path) @@ -712,7 +713,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): try_sync_file(yaml_file) config_args = MindFormerConfig(yaml_file) kwargs["checkpoint_name_or_path"] = kwargs.get("checkpoint_name_or_path") \ - if "checkpoint_name_or_path" in kwargs.keys() else pretrained_model_name_or_dir + if "checkpoint_name_or_path" in kwargs else pretrained_model_name_or_dir config_args.model.model_config.update(**kwargs) return config_args @@ -1415,8 +1416,6 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): if not isinstance(auto_class, str): auto_class = auto_class.__name__ - import mindformers.models.auto as auto_module - if not hasattr(auto_module, auto_class): raise ValueError(f"{auto_class} is not a valid auto class.") diff --git a/mindformers/tools/__init__.py b/mindformers/tools/__init__.py index 8f4a2733a8b02252e9cdf3f5c73991648a38be49..b4d4eefa0567c3dc199f9ec0d909d0773e8b3859 100644 --- a/mindformers/tools/__init__.py +++ b/mindformers/tools/__init__.py @@ -34,8 +34,6 @@ from .utils import ( count_params, get_output_root_path, get_output_subpath, - set_output_path, - set_strategy_save_path, str2bool, calculate_pipeline_stage, is_last_pipeline_stage diff --git a/mindformers/tools/ckpt_transform/transform_checkpoint.py b/mindformers/tools/ckpt_transform/transform_checkpoint.py index bca8bf70868386c6ff19266afa81e89fbc168efc..0ee7d25b0633e10ed34f6bbee089c20b0e1d938e 100644 --- a/mindformers/tools/ckpt_transform/transform_checkpoint.py +++ b/mindformers/tools/ckpt_transform/transform_checkpoint.py @@ -30,13 +30,15 @@ from mindformers.tools.utils import ( get_output_root_path, get_remote_save_url, get_device_num_per_node, + is_main_rank, + format_path +) +from mindformers.utils.file_utils import ( create_file, delete_file, - remake_folder, - is_main_rank, - format_path, - barrier_world + remake_folder ) +from mindformers.utils.parallel_utils import barrier_world from mindformers.tools.logger import logger from mindformers.tools.cloud_adapter import mox_adapter from mindformers.tools.ckpt_transform.utils import ( @@ -107,36 +109,9 @@ class TransformCkpt: self.node_num = self.world_size // self.npu_num_per_node if not is_power_of_two(self.npu_num_per_node): raise ValueError( - f"The `npu_num_per_node` must be a power of 2, but get {npu_num_per_node}") + f"The `npu_num_per_node` must be a power of 2, but get {self.npu_num_per_node}") - # Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size. - if transform_process_num < 1: - raise ValueError("transform_process_num should not smaller than 1," - f"but got {transform_process_num}.") - if transform_process_num > self.world_size: - logger.warning("transform_process_num: %d should not bigger than world_size: %d. \ - transform_process_num is set to %d.", - transform_process_num, self.world_size, self.world_size) - transform_process_num = self.world_size - if self.world_size % transform_process_num != 0: - raise ValueError(f"transform_process_num: {transform_process_num} " - f"should be divided by world_size: {self.world_size}.") - if check_in_modelarts() and 1 < transform_process_num < self.node_num: - logger.warning("transform_process_num: %d should not smaller than \ - node_num = world_size // npu_num_per_node = %d when training on AICC. \ - transform_process_num is set to node num = %d", - transform_process_num, self.node_num, self.node_num) - transform_process_num = self.world_size // npu_num_per_node - if check_in_modelarts() and transform_process_num == 1: - # The 0th NPU of each node is responsible for transform all checkpoints. - # For example, if world_size is 16 and npu_num_per_node is 8, then transform_rank_id_list should be [0,8]. - self.transform_rank_id_list = [i for i in range(0, self.world_size, self.npu_num_per_node)] - else: - # Obtain transform_rank_id_list. For example, - # if world_size is 8 and transform_process_num is 2, then transform_rank_id_list should be [0,4]. - # which means that the 0th rank and the 4th rank responsible for transform checkpoints. - self.transform_rank_id_list = \ - [i for i in range(0, self.world_size, self.world_size // transform_process_num)] + self.transform_rank_id_list = self._get_transform_rank_id_list(transform_process_num) self.transform_process_num = len(self.transform_rank_id_list) if auto_trans_ckpt: @@ -377,10 +352,10 @@ class TransformCkpt: raise ValueError(f"The checkpoint of rank_{src_rank_id} is not found!") checkpoint_file_list = sorted(checkpoint_file_list, key=os.path.getmtime) checkpoint_file_map[src_rank_id] = checkpoint_file_list[-1] - save_checkpoint_dir = os.path.join(dst_checkpoint, "rank_{}".format(current_transform_rank_id)) + save_checkpoint_dir = os.path.join(dst_checkpoint, f"rank_{current_transform_rank_id}") os.makedirs(save_checkpoint_dir, exist_ok=True) save_checkpoint_path = os.path.join(save_checkpoint_dir, - "{}.ckpt".format(prefix + str(current_transform_rank_id))) + f"{prefix + str(current_transform_rank_id)}.ckpt") logger.info("rank_list: %s", src_rank_list) logger.info("checkpoint_file_map: %s", checkpoint_file_map) logger.info("save_checkpoint_path: %s", save_checkpoint_path) @@ -452,7 +427,7 @@ class TransformCkpt: if rank_id: merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy_by_rank_{rank_id}.ckpt') else: - merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy.ckpt') + merge_path = os.path.join(strategy_path, 'merged_ckpt_strategy.ckpt') merged_succeed_txt = os.path.join(strategy_path, "merge_succeed.txt") if self.is_main_rank: @@ -516,6 +491,67 @@ class TransformCkpt: return dst_strategy + def _get_transform_rank_id_list(self, transform_process_num): + """ + Generates a list of distributed rank IDs responsible for checkpoint transformation. + + This internal method calculates and returns the list of rank IDs assigned to handle data/checkpoint + transformation in a distributed training environment. It first validates input constraints for + `transform_process_num`, applies environment-specific adjustments (for ModelArts/AICC platforms), + and then computes the rank IDs based on cluster configuration (world size, per-node NPU count). + + The rank ID assignment follows two core strategies: + 1. ModelArts/AICC environment with `transform_process_num=1`: Assigns the 0th NPU of each node + (e.g., world_size=16, npu_num_per_node=8 → rank IDs [0, 8]) + 2. Other scenarios: Distributes rank IDs evenly across the total world size + (e.g., world_size=8, transform_process_num=2 → rank IDs [0, 4]) + + Args: + transform_process_num (int): Number of processes allocated for transformation tasks. + Must be a positive integer that divides `self.world_size` (after potential clamping). + + Returns: + list[int]: Sorted list of rank IDs designated for transformation. + The length of the list equals the final `transform_process_num` (after adjustments). + + Raises: + ValueError: If: + - `transform_process_num` is less than 1 + - `self.world_size` is not divisible by `transform_process_num` (before platform-specific adjustments) + """ + # Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size. + if transform_process_num < 1: + raise ValueError("transform_process_num should not smaller than 1," + f"but got {transform_process_num}.") + if transform_process_num > self.world_size: + logger.warning(f"transform_process_num: {transform_process_num} should not " + f"bigger than world_size: {self.world_size}. " + f"transform_process_num is set to {self.world_size}.") + transform_process_num = self.world_size + if self.world_size % transform_process_num != 0: + raise ValueError(f"transform_process_num: {transform_process_num} " + f"should be divided by world_size: {self.world_size}.") + + if check_in_modelarts() and 1 < transform_process_num < self.node_num: + logger.warning("transform_process_num: %d should not smaller than \ + node_num = world_size // npu_num_per_node = %d when training on AICC. \ + transform_process_num is set to node num = %d", + transform_process_num, self.node_num, self.node_num) + transform_process_num = self.world_size // self.npu_num_per_node + + if check_in_modelarts() and transform_process_num == 1: + # The 0th NPU of each node is responsible for transform all checkpoints. + # For example, if world_size is 16 and npu_num_per_node is 8, then transform_rank_id_list should be [0,8]. + transform_rank_id_list = list(range(0, self.world_size, self.npu_num_per_node)) + else: + # Obtain transform_rank_id_list. For example, + # if world_size is 8 and transform_process_num is 2, then transform_rank_id_list should be [0,4]. + # which means that the 0th rank and the 4th rank responsible for transform checkpoints. + transform_rank_id_list = list(range(0, self.world_size, self.world_size // transform_process_num)) + + return transform_rank_id_list + + @staticmethod def check_src_checkpoint_and_strategy(src_checkpoint, src_strategy): """check src checkpoint and strategy""" @@ -599,12 +635,12 @@ class TransformCkpt: if check_in_modelarts(): transformed_ckpt_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs, os.path.basename(ckpt_dir)) transform_failed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs, - f'transform_failed_rank_*.txt')) + 'transform_failed_rank_*.txt')) transform_succeed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs, - f'transform_succeed_rank_*.txt')) + 'transform_succeed_rank_*.txt')) else: - transform_failed_txts = glob(os.path.join(ckpt_dir, f'transform_failed_rank_*.txt')) - transform_succeed_txts = glob(os.path.join(ckpt_dir, f'transform_succeed_rank_*.txt')) + transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt')) + transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt')) if transform_failed_txts: raise ValueError(f"Transform failed, find {transform_failed_txts}.") current_count = len(transform_succeed_txts) diff --git a/mindformers/tools/resume_ckpt.py b/mindformers/tools/resume_ckpt.py index d3dd9c603a9b2a76caa31224e531d2df88eca92b..7a9003e4a80f3b5b5e53870b97089ee60d370db1 100644 --- a/mindformers/tools/resume_ckpt.py +++ b/mindformers/tools/resume_ckpt.py @@ -18,7 +18,6 @@ import os from mindformers.tools.logger import logger from mindformers.tools.utils import ( - create_file, check_in_modelarts, check_ckpt_file_name, get_real_rank, @@ -27,12 +26,15 @@ from mindformers.tools.utils import ( get_rank_id_from_ckpt_name, get_remote_save_url, replace_rank_id_in_ckpt_name, - remake_folder, - is_publicly_accessible_path, get_output_root_path, - barrier_world, Validator ) +from mindformers.utils.file_utils import ( + create_file, + remake_folder, + is_publicly_accessible_path +) +from mindformers.utils.parallel_utils import barrier_world from mindformers.trainer.utils import get_last_checkpoint, is_hyper_param_existed_in_sf_dir if check_in_modelarts(): diff --git a/mindformers/tools/utils.py b/mindformers/tools/utils.py index ce6bea613d8a49b89096bc245cbace9e3eb56df6..831eb5a877a22d611f29b15b626111fa2b07a363 100644 --- a/mindformers/tools/utils.py +++ b/mindformers/tools/utils.py @@ -16,7 +16,6 @@ import json import os import re -import shutil import stat import tempfile from multiprocessing import Process @@ -33,8 +32,8 @@ except ImportError: import mindspore as ms from mindspore import Tensor, context -from mindspore._checkparam import args_type_check -from mindspore.communication import get_group_size, get_rank, comm_func, get_local_rank +from mindspore.communication import get_group_size, get_rank, get_local_rank + PARALLEL_MODE = {'DATA_PARALLEL': context.ParallelMode.DATA_PARALLEL, 'SEMI_AUTO_PARALLEL': context.ParallelMode.SEMI_AUTO_PARALLEL, @@ -85,7 +84,7 @@ class Validator: def check_type(arg_value, arg_type): """Check int.""" if not isinstance(arg_value, arg_type): - raise TypeError('{} should be {} type, but get {}'.format(arg_value, arg_type, type(arg_value))) + raise TypeError(f'{arg_value} should be {arg_type} type, but get {type(arg_value)}') @staticmethod def is_obs_url(url): @@ -96,11 +95,11 @@ class Validator: def check_obs_url(url): """Check obs url.""" if not isinstance(url, str): - raise TypeError('remote_save_url type should be a str, but get {}, ' - 'please check your remote_save_url config'.format(type(url))) + raise TypeError(f'remote_save_url type should be a str, but get {type(url)}, ' + 'please check your remote_save_url config') if not (url.startswith(_PROTOCOL + '://') or url.startswith(_PROTOCOL_S3 + '://')): raise TypeError('remote_save_url should be start with obs:// or s3://, ' - 'but get {}, please check your remote_save_url config'.format(url)) + f'but get {url}, please check your remote_save_url config') def check_list(var_name: str, list_var: Union[Tuple, List], num: int): @@ -116,7 +115,7 @@ def check_list(var_name: str, list_var: Union[Tuple, List], num: int): """ for value in list_var: if value >= num: - raise ValueError('The index of the {} needs to be less than the number of nodes {}.'.format(var_name, num)) + raise ValueError(f'The index of the {var_name} needs to be less than the number of nodes {num}.') def check_file(file_path, file_type=None): @@ -154,40 +153,6 @@ def get_output_root_path(): return os.path.realpath(expanduser_path) -@args_type_check(path=str) -def set_output_path(path): - """set output path""" - from .logger import logger - if path is None: - path = './output' - expanduser_path = os.path.expanduser(path) - os.environ['LOCAL_DEFAULT_PATH'] = os.path.realpath(expanduser_path) - logger.info(f"set output path to '{os.path.realpath(expanduser_path)}'") - - -def set_strategy_save_path(config): - """set strategy path""" - from .logger import logger - rank_id = get_real_rank() - strategy_ckpt_save_dir = os.path.join(get_output_root_path(), "strategy") - os.makedirs(strategy_ckpt_save_dir, exist_ok=True) - set_safe_mode_for_file_or_dir(strategy_ckpt_save_dir) - - strategy_ckpt_save_file = config.get('strategy_ckpt_save_file', "ckpt_strategy.ckpt") - if not strategy_ckpt_save_file.endswith(f"_rank_{rank_id}.ckpt"): - strategy_name = os.path.basename(strategy_ckpt_save_file).replace(".ckpt", f"_rank_{rank_id}.ckpt") - config['strategy_ckpt_save_file'] = os.path.join(strategy_ckpt_save_dir, strategy_name) - context.set_auto_parallel_context(strategy_ckpt_save_file=config['strategy_ckpt_save_file']) - logger.info(f"set strategy path to '{config['strategy_ckpt_save_file']}'") - -def set_checkpoint_save_path(): - """set checkpoint save path""" - from .logger import logger - checkpoint_save_path = os.path.join(get_output_root_path(), "checkpoint") - os.makedirs(checkpoint_save_path, exist_ok=True) - set_safe_mode_for_file_or_dir(checkpoint_save_path) - logger.info(f"set checkpoint save path to `{checkpoint_save_path}`") - def get_log_path(): path = os.getenv("LOG_MF_PATH", os.path.join(get_output_root_path(), "log")) return os.path.expanduser(path) @@ -199,7 +164,7 @@ def get_output_subpath(sub_class, rank_id=0, append_rank=True): root_path = get_output_root_path() directory = os.path.join(root_path, sub_class) if append_rank: - directory = os.path.join(directory, 'rank_{}'.format(rank_id)) + directory = os.path.join(directory, f'rank_{rank_id}') return format_path(directory) @@ -245,12 +210,13 @@ def get_num_nodes_devices(rank_size: int) -> Tuple[int, int]: num_nodes (int): number of nodes. num_devices (int): number of devices. """ - if rank_size in (2, 4, 8): + device_num_per_node = get_device_num_per_node() + if rank_size <= device_num_per_node: num_nodes = 1 num_devices = rank_size else: - num_nodes = rank_size // 8 - num_devices = 8 + num_nodes = rank_size // device_num_per_node + num_devices = device_num_per_node return num_nodes, num_devices @@ -260,9 +226,9 @@ class Const: def __setattr__(self, key, value): if key in self.__dict__: - raise PermissionError('Can not change const {0}.'.format(key)) + raise PermissionError(f'Can not change const {key}.') if not key.isupper(): - raise ValueError('Const name {0} is not all uppercase.'.format(key)) + raise ValueError(f'Const name {key} is not all uppercase.') self.__dict__[key] = value @@ -337,7 +303,7 @@ def count_params(net): def try_sync_file(file_name): """If the file is still downloading, we need to wait before the file finished downloading""" if fcntl: - with open(file_name, 'r') as fp: + with open(file_name, 'r', encoding='utf-8') as fp: fcntl.flock(fp.fileno(), fcntl.LOCK_EX) @@ -389,16 +355,14 @@ def parse_value(value): b = int(a) except (TypeError, ValueError): return False - else: - return a == b + return a == b def isfloat(x): try: float(x) except (TypeError, ValueError): return False - else: - return True + return True def isbool(x): return x in ["True", "False"] @@ -408,8 +372,7 @@ def parse_value(value): json.loads(x) except json.decoder.JSONDecodeError: return False - else: - return True + return True if isint(value): return int(value) @@ -478,7 +441,7 @@ def get_dp_from_dataset_strategy(): first_input_stra = data_strategy[0] dp = int(first_input_stra[0]) else: - raise TypeError(f"Dataset_strategy in mindspore auto parallel context is invalid, only support (tuple, list)") + raise TypeError("Dataset_strategy in mindspore auto parallel context is invalid, only support (tuple, list)") return dp @@ -512,116 +475,6 @@ def is_last_pipeline_stage(): return (rank // device_num_per_stage + 1) == stage_num -def is_publicly_accessible_path(path): - """Check a path is accessible by all rank.""" - from .logger import logger - if get_real_group_size() <= get_device_num_per_node(): - return True - - if check_in_modelarts(): - return True - - if check_shared_disk(path): - return True - - # For example, SHARED_PATHS="/mnt/shared1,/mnt/shared2", which will be split by "/mnt/shared1" and "/mnt/shared2". - shared_paths = os.getenv("SHARED_PATHS", "").split(',') - path = os.path.realpath(path) - for shared_path in shared_paths: - if not shared_path: - continue - shared_path = os.path.realpath(shared_path) - if path.startswith(shared_path): - return True - logger.info("System can not identify if given path is shared disk. " - "If it is, Please set env 'SHARED_PATHS' to given path.") - return False - - -def create_file(file_path, info=None): - """create file.""" - if Validator.is_obs_url(file_path): - if not check_in_modelarts(): - raise ValueError(f"When create {file_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - with mox.file.File(file_path, 'w') as f: - if info: - if isinstance(info, list): - for sub_info in info: - f.write(str(sub_info) + "\n") - else: - f.write(info) - else: - f.write("Hugging ModelArts.") - else: - flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC - with os.fdopen(os.open(file_path, flags_, FILE_PERMISSION), 'w') as f: - if info: - if isinstance(info, list): - for sub_info in info: - f.write(str(sub_info) + "\n") - else: - f.write(info) - - -def delete_file(file_path): - """delete file""" - if Validator.is_obs_url(file_path): - if not check_in_modelarts(): - raise ValueError(f"When create {file_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - if mox.file.exists(file_path): - mox.file.remove(file_path, recursive=False) - else: - if os.path.exists(file_path): - os.remove(file_path) - - -def remake_folder(folder_path, permissions=None): - """make folder""" - from .logger import logger - rank_id = get_real_rank() - logger.info("Remake %s...", folder_path) - if Validator.is_obs_url(folder_path): - if not check_in_modelarts(): - raise ValueError(f"When remaking {folder_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - if not rank_id: - if mox.file.exists(folder_path): - mox.file.remove(folder_path, recursive=True) - mox.file.make_dirs(folder_path) - logger.info("OBS: Folder %s is remaked.", folder_path) - else: - if is_main_rank(): - if os.path.exists(folder_path): - shutil.rmtree(folder_path) - os.makedirs(folder_path, exist_ok=True) - os.chmod(folder_path, permissions) - logger.info("Folder %s is remaked.", folder_path) - - -def remove_folder(folder_path, rank_id=None): - """delete folder""" - from .logger import logger - rank_id = rank_id or get_real_rank() - logger.info("Remove %s...", folder_path) - if Validator.is_obs_url(folder_path): - if not check_in_modelarts(): - raise ValueError(f"When removing {folder_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - if mox.file.exists(folder_path) and not rank_id: - mox.file.remove(folder_path, recursive=True) - logger.info("OBS: Folder %s is removed.", folder_path) - else: - if os.path.exists(folder_path) and is_main_rank(): - shutil.rmtree(folder_path) - logger.info("Folder %s is removed.", folder_path) - - def set_safe_mode_for_file_or_dir(path): if isinstance(path, str): path = [path] @@ -688,22 +541,6 @@ def replace_rank_id_in_ckpt_name(ckpt_file, dst_rank_id): return ckpt_name -def clear_auto_trans_output(load_checkpoint=None, src_strategy_path_or_dir=None): - """clear transformed_checkpoint and strategy""" - folder_list = ["strategy", "transformed_checkpoint"] - for folder in folder_list: - if check_in_modelarts(): - folder_path = os.path.join(get_remote_save_url(), folder) - else: - folder_path = os.path.join(get_output_root_path(), folder) - if os.path.realpath(folder_path) in (load_checkpoint, src_strategy_path_or_dir): - raise ValueError( - "./transformed_checkpoint or ./strategy with given config.output_dir is same as " - "load_checkpoint or src_strategy_path_or_dir which is not allowed when auto_trans is True." - "Please move it to a different location or specify a different output folder.") - remake_folder(folder_path, permissions=0o750) - - def check_ckpt_file_name(ckpt_file, ckpt_fmt='ckpt'): """Check ckpt name in the format of {prefix}-{epoch}_{step}.ckpt""" ckpt_name = os.path.split(ckpt_file)[1] @@ -736,18 +573,6 @@ def is_pynative(): return enforce_eager.lower() == "true" -def barrier_world(action: str = None): - """barrier all rank until action is done""" - if get_real_group_size() > 1: - from .logger import logger - if action is not None: - logger.info("Wait " + str(action)) - else: - logger.info("Now barriered...") - - comm_func.barrier() - - def get_pipeline_rank_ids(): """Calculate rank id of each stage and return a list of first rank id in each stage. @@ -770,7 +595,7 @@ def get_pipeline_rank_ids(): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" if numerator % denominator != 0: - raise ValueError("{} is not divisible by {}".format(numerator, denominator)) + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator, denominator): diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 60ce12c29f236a7c7389feda138727963f037c31..8b1bafc4dace848858523ca9528929444a73bdfe 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -42,8 +42,6 @@ from mindformers.models import PreTrainedModel, BaseImageProcessor, \ PreTrainedTokenizerBase, BaseAudioProcessor from mindformers.models.utils import WEIGHTS_NAME from mindformers.tools.utils import ( - set_output_path, - set_strategy_save_path, set_context, check_in_modelarts, get_real_rank, @@ -51,9 +49,14 @@ from mindformers.tools.utils import ( set_remote_save_url, get_output_root_path, get_device_num_per_node, + FILE_PERMISSION, +) +from mindformers.utils.file_utils import ( + set_output_path, + set_strategy_save_path, is_publicly_accessible_path, - clear_auto_trans_output, - FILE_PERMISSION, set_checkpoint_save_path + set_checkpoint_save_path, + clear_auto_trans_output ) from mindformers.tools.logger import logger from mindformers.tools.register import MindFormerConfig @@ -1396,12 +1399,13 @@ class Trainer: def _validate_auto_trans_ckpt_requirements(self): """Ensure auto_trans_ckpt has shared directory requirements met.""" - if self.config.auto_trans_ckpt and self.config.load_ckpt_format == 'ckpt': + if self.config.auto_trans_ckpt: if not is_publicly_accessible_path(get_output_root_path()): raise ValueError(f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True, " f"the output_dir should be a shared directory that can be accessed by all nodes. " f"But {os.path.abspath(self.config.output_dir)} is not a shared directory.") - clear_auto_trans_output(self.config.load_checkpoint, self.config.src_strategy_path_or_dir) + clear_auto_trans_output( + self.config.load_checkpoint, self.config.src_strategy_path_or_dir, self.config.load_ckpt_format) def _using_checkpoint_name_or_path_if_needed(self): """Using model_config.checkpoint_name_or_path if possible.""" diff --git a/mindformers/trainer/utils.py b/mindformers/trainer/utils.py index 971dfcff3e39d1de5b111e49d7ee1fad20117157..fb306cef0f71f6dcae2c82bb3cfe98438b7fe7df 100644 --- a/mindformers/trainer/utils.py +++ b/mindformers/trainer/utils.py @@ -27,7 +27,9 @@ from mindspore import context, load_checkpoint, load_param_into_net, load_checkp from mindspore import set_seed as ms_set_seed from mindspore import Parameter from mindspore import ops, mint +from mindspore import save_checkpoint from mindspore.communication.comm_func import barrier +from mindspore.common.file_system import mindio_preload, set_mindio_server_info from mindformers.tools.logger import logger from mindformers.tools.utils import get_real_rank @@ -126,37 +128,36 @@ def preload_ckpt(config): """Preload data into memory using MindIO for faster access.""" rank_id = get_real_rank() if get_real_rank() else 0 ckpt_path = config.load_checkpoint - try: - from mindspore.common.file_system import _init_mindio, mindio_preload, set_mindio_server_info - except ImportError: - return mindio_pool_capacity = config.get("mindio_pool_capacity", 128) set_mindio_server_info(mindio_pool_capacity) - if hasattr(_init_mindio(), "preload"): - logger.info("MindIO is initialized successfully!") - else: - return + if not os.path.realpath(ckpt_path) or not os.path.exists(ckpt_path): raise FileNotFoundError(f"The load_checkpoint must be correct, but get {ckpt_path}") + + preload_ok = False if os.path.isfile(ckpt_path): # only preload the ckpt file once. if is_main_rank() or f"rank_{rank_id}" in ckpt_path: logger.info(f"MindIO preloading `{ckpt_path}`...") - mindio_preload(ckpt_path) + preload_ok = mindio_preload(ckpt_path) elif os.path.isdir(ckpt_path) and check_ckpt_file_exist(ckpt_path): for ckpt_file in os.listdir(ckpt_path): # only preload every ckpt file once in rank_0 process. if ckpt_file.endswith('.ckpt') and is_main_rank(): checkpoint_path = os.path.join(ckpt_path, ckpt_file) logger.info(f"MindIO preloading `{checkpoint_path}`...") - mindio_preload(checkpoint_path) + preload_ok = mindio_preload(checkpoint_path) elif os.path.isdir(ckpt_path) and check_rank_folders(ckpt_path, rank_id): # preload ckpt file of rank in corresponding rank process. checkpoint_path = get_distribute_checkpoint_path(ckpt_path) logger.info(f"MindIO preloading `{checkpoint_path}`...") - mindio_preload(checkpoint_path) + preload_ok = mindio_preload(checkpoint_path) else: raise ValueError(f"{ckpt_path} is not a valid path to load checkpoint when auto_trans_ckpt is False.") + + if preload_ok: + logger.info("MindIO preload checkpoint successfully!") + if config.use_parallel: barrier() @@ -312,7 +313,7 @@ def get_distribute_checkpoint_path(checkpoint_dir, rank_id=None, ckpt_format='ck "load_checkpoint should be a checkpoint directory containing the directory of rank_{0-*}," "The directory structure is as follows: **checkpoint_root_dir/rank_{0-*}/**.ckpt") rank_id = rank_id if rank_id is not None else get_real_rank() - distribute_checkpoint_dir = os.path.join(checkpoint_dir, "rank_{}".format(rank_id)) + distribute_checkpoint_dir = os.path.join(checkpoint_dir, f"rank_{rank_id}") distribute_checkpoint_path = get_last_checkpoint(distribute_checkpoint_dir, ckpt_format) logger.info("distribute checkpoint dir: %s", distribute_checkpoint_dir) elif os.path.isfile(checkpoint_dir): @@ -350,7 +351,9 @@ def load_resume_context_from_checkpoint(config, dataset): raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}") if os.path.isdir(config.load_checkpoint): - if config.use_graceful_exit: + # When graceful exit is enabled or auto checkpoint transformation is disabled, + # use real rank ID to locate the checkpoint. + if config.use_graceful_exit or not config.auto_trans_ckpt: rank_id = get_real_rank() else: rank_id = 0 @@ -568,14 +571,14 @@ def load_slora_ckpt(checkpoint_dict, config, network): adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json") if not os.path.exists(adapter_path): raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}") - with open(adapter_path, 'r') as file: + with open(adapter_path, 'r', encoding='utf-8') as file: path_dict = json.load(file) adapter_list = [] config_list = [] for adapter_name in network.lora_adapter.adapter_names[1:]: if adapter_name in path_dict.keys(): adapter_model = load_checkpoint(os.path.join(path_dict[adapter_name], "adapter_model.ckpt")) - with open(os.path.join(path_dict[adapter_name], "adapter_config.json"), 'r') as file: + with open(os.path.join(path_dict[adapter_name], "adapter_config.json"), 'r', encoding='utf-8') as file: adapter_config = json.load(file) else: adapter_model = {} @@ -611,7 +614,6 @@ def load_slora_ckpt(checkpoint_dict, config, network): dst_checkpoint_dir = pet_config.adapter_path if config.auto_trans_ckpt: # Save collected lora weights as single ckpt - from mindspore import save_checkpoint src_checkpoint_dir = os.path.join(config.output_dir, "slora_checkpoint") os.makedirs(src_checkpoint_dir, exist_ok=True) src_checkpoint_dir = os.path.join(src_checkpoint_dir, "slora.ckpt") diff --git a/mindformers/utils/__init__.py b/mindformers/utils/__init__.py index 801fe92793252fb3718dd31c5e409edc77cdc30b..38397eacfe23a77f58577394f5aa4e3b70d13dd9 100644 --- a/mindformers/utils/__init__.py +++ b/mindformers/utils/__init__.py @@ -28,3 +28,15 @@ from .safetensors import ( check_safetensors_key, ) from .decorators import deprecated +from .parallel_utils import barrier_world +from .file_utils import ( + set_output_path, + set_strategy_save_path, + set_checkpoint_save_path, + is_publicly_accessible_path, + remake_folder, + remove_folder, + create_file, + delete_file, + clear_auto_trans_output +) diff --git a/mindformers/utils/file_utils.py b/mindformers/utils/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d99a4c0fa4cffb834b02d1f77c92fcbb6f9846 --- /dev/null +++ b/mindformers/utils/file_utils.py @@ -0,0 +1,212 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" File utils.""" +import os +import shutil + +from mindspore import context +from mindspore._checkparam import args_type_check + +from mindformers.tools.logger import logger +from mindformers.tools.utils import ( + Validator, + FILE_PERMISSION, + check_in_modelarts, + get_real_rank, + set_safe_mode_for_file_or_dir, + get_output_root_path, + get_real_group_size, + get_device_num_per_node, + get_remote_save_url, + check_shared_disk, + is_main_rank +) + +if check_in_modelarts(): + import moxing as mox +else: + mox = None + + +@args_type_check(path=str) +def set_output_path(path): + """set output path""" + if path is None: + path = './output' + expanduser_path = os.path.expanduser(path) + os.environ['LOCAL_DEFAULT_PATH'] = os.path.realpath(expanduser_path) + logger.info(f"set output path to '{os.path.realpath(expanduser_path)}'") + + +def set_strategy_save_path(config): + """set strategy path""" + rank_id = get_real_rank() + strategy_ckpt_save_dir = os.path.join(get_output_root_path(), "strategy") + os.makedirs(strategy_ckpt_save_dir, exist_ok=True) + set_safe_mode_for_file_or_dir(strategy_ckpt_save_dir) + + strategy_ckpt_save_file = config.get('strategy_ckpt_save_file', "ckpt_strategy.ckpt") + if not strategy_ckpt_save_file.endswith(f"_rank_{rank_id}.ckpt"): + strategy_name = os.path.basename(strategy_ckpt_save_file).replace(".ckpt", f"_rank_{rank_id}.ckpt") + config['strategy_ckpt_save_file'] = os.path.join(strategy_ckpt_save_dir, strategy_name) + context.set_auto_parallel_context(strategy_ckpt_save_file=config['strategy_ckpt_save_file']) + logger.info(f"set strategy path to '{config['strategy_ckpt_save_file']}'") + + +def set_checkpoint_save_path(): + """set checkpoint save path""" + checkpoint_save_path = os.path.join(get_output_root_path(), "checkpoint") + os.makedirs(checkpoint_save_path, exist_ok=True) + set_safe_mode_for_file_or_dir(checkpoint_save_path) + logger.info(f"set checkpoint save path to `{checkpoint_save_path}`") + + +def is_publicly_accessible_path(path): + """Check a path is accessible by all rank.""" + if get_real_group_size() <= get_device_num_per_node(): + return True + + if check_in_modelarts(): + return True + + if check_shared_disk(path): + return True + + # For example, SHARED_PATHS="/mnt/shared1,/mnt/shared2", which will be split by "/mnt/shared1" and "/mnt/shared2". + shared_paths = os.getenv("SHARED_PATHS", "").split(',') + path = os.path.realpath(path) + for shared_path in shared_paths: + if not shared_path: + continue + shared_path = os.path.realpath(shared_path) + if path.startswith(shared_path): + return True + logger.info("System can not identify if given path is shared disk. " + "If it is, Please set env 'SHARED_PATHS' to given path.") + return False + + +def remake_folder(folder_path, permissions=None): + """make folder""" + rank_id = get_real_rank() + logger.info("Remake %s...", folder_path) + if Validator.is_obs_url(folder_path): + if not check_in_modelarts(): + raise ValueError(f"When remaking {folder_path}, \ + it is detected that it is not in the ModelArts platform.") + if not rank_id: + if mox.file.exists(folder_path): + mox.file.remove(folder_path, recursive=True) + mox.file.make_dirs(folder_path) + logger.info("OBS: Folder %s is remaked.", folder_path) + else: + if is_main_rank(): + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + os.makedirs(folder_path, exist_ok=True) + os.chmod(folder_path, permissions) + logger.info("Folder %s is remaked.", folder_path) + + +def remove_folder(folder_path, rank_id=None): + """delete folder""" + rank_id = rank_id or get_real_rank() + logger.info("Remove %s...", folder_path) + if Validator.is_obs_url(folder_path): + if not check_in_modelarts(): + raise ValueError(f"When removing {folder_path}, \ + it is detected that it is not in the ModelArts platform.") + if mox.file.exists(folder_path) and not rank_id: + mox.file.remove(folder_path, recursive=True) + logger.info("OBS: Folder %s is removed.", folder_path) + else: + if os.path.exists(folder_path) and is_main_rank(): + shutil.rmtree(folder_path) + logger.info("Folder %s is removed.", folder_path) + + +def create_file(file_path, info=None): + """create file.""" + if Validator.is_obs_url(file_path): + if not check_in_modelarts(): + raise ValueError(f"When create {file_path}, \ + it is detected that it is not in the ModelArts platform.") + with mox.file.File(file_path, 'w') as f: + if info: + if isinstance(info, list): + for sub_info in info: + f.write(str(sub_info) + "\n") + else: + f.write(info) + else: + f.write("Hugging ModelArts.") + else: + flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + with os.fdopen(os.open(file_path, flags_, FILE_PERMISSION), 'w') as f: + if info: + if isinstance(info, list): + for sub_info in info: + f.write(str(sub_info) + "\n") + else: + f.write(info) + + +def delete_file(file_path): + """delete file""" + if Validator.is_obs_url(file_path): + if not check_in_modelarts(): + raise ValueError(f"When create {file_path}, \ + it is detected that it is not in the ModelArts platform.") + if mox.file.exists(file_path): + mox.file.remove(file_path, recursive=False) + else: + if os.path.exists(file_path): + os.remove(file_path) + + +def clear_auto_trans_output(load_checkpoint=None, src_strategy_path_or_dir=None, load_ckpt_format='ckpt'): + """ + Clear transformed_checkpoint and strategy folders based on the specified checkpoint format. + + Args: + load_checkpoint (str, optional): Path to the checkpoint file/directory to load. + If the target folder path matches this value, a ValueError is raised. Defaults to None. + src_strategy_path_or_dir (str, optional): Path to the source strategy file/directory. + If the target folder path matches this value, a ValueError is raised. Defaults to None. + load_ckpt_format (str, optional): Format of the checkpoint file. Supports 'ckpt' and 'safetensors'. + Determines which transformed checkpoint folder to clear. Defaults to 'ckpt'. + + Raises: + ValueError: If `load_ckpt_format` is not one of the supported values ('ckpt' or 'safetensors'). + ValueError: If the resolved folder path (strategy or transformed/unified checkpoint) is the same as + `load_checkpoint` or `src_strategy_path_or_dir`. This prevents overwriting critical input data. + """ + if load_ckpt_format not in ('ckpt', 'safetensors'): + raise ValueError(f"Invalid checkpoint format '{load_ckpt_format}'. " + "Only 'ckpt' and 'safetensors' formats are supported.") + + folder_list = ["strategy", "transformed_checkpoint"] if load_ckpt_format == 'ckpt' \ + else ["strategy", "unified_checkpoint"] + for folder in folder_list: + if check_in_modelarts(): + folder_path = os.path.join(get_remote_save_url(), folder) + else: + folder_path = os.path.join(get_output_root_path(), folder) + if os.path.realpath(folder_path) in (load_checkpoint, src_strategy_path_or_dir): + raise ValueError( + "./transformed_checkpoint or ./strategy with given config.output_dir is same as " + "load_checkpoint or src_strategy_path_or_dir which is not allowed when auto_trans is True." + "Please move it to a different location or specify a different output folder.") + remake_folder(folder_path, permissions=0o750) diff --git a/mindformers/utils/load_checkpoint_utils.py b/mindformers/utils/load_checkpoint_utils.py index d23375f461eca27f26f23e5decc6d6a12c486cf9..2233f553284921da8f9f7a4eac5650d2ee07d59e 100644 --- a/mindformers/utils/load_checkpoint_utils.py +++ b/mindformers/utils/load_checkpoint_utils.py @@ -28,12 +28,8 @@ from mindspore.common.api import _pynative_executor from mindspore.communication.comm_func import barrier from mindformers.tools.logger import logger -from mindformers.tools.utils import ( - is_main_rank, - get_real_rank, - clear_auto_trans_output, - barrier_world -) +from mindformers.tools.utils import is_main_rank, get_real_rank +from mindformers.utils.parallel_utils import barrier_world from mindformers.utils import convert_hf_safetensors_multiprocess, check_safetensors_key, is_hf_safetensors_dir from mindformers.utils.safetensors.convert_safetensors import _convert_index_json from mindformers.checkpoint.utils import compile_model @@ -402,7 +398,6 @@ def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_paral ) unify_time_end = time.time() logger.info("Time spent unifying safetensors: %.2fs", unify_time_end - unify_time_start) - clear_auto_trans_output() else: logger.info("Wait for rank_0 to unify the safetensors, " "please check the log of rank_0 to get the unify progress.") diff --git a/mindformers/utils/parallel_utils.py b/mindformers/utils/parallel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f4610f4bb57bbae5a8fba5e6a53fdc7db192e230 --- /dev/null +++ b/mindformers/utils/parallel_utils.py @@ -0,0 +1,31 @@ + +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Process utils.""" +from mindspore.communication import comm_func + +from mindformers.tools.logger import logger +from mindformers.tools.utils import get_real_group_size + + +def barrier_world(action: str = None): + """barrier all rank until action is done""" + if get_real_group_size() > 1: + if action is not None: + logger.info("Wait " + str(action)) + else: + logger.info("Now barriered...") + + comm_func.barrier() diff --git a/mindformers/utils/resume_ckpt_utils.py b/mindformers/utils/resume_ckpt_utils.py index 587f650316b59fb6b80132e28dde118cd19cb5ca..82628be9b9fdefd4eae8c467af1bb4b6cc8a4068 100644 --- a/mindformers/utils/resume_ckpt_utils.py +++ b/mindformers/utils/resume_ckpt_utils.py @@ -19,7 +19,8 @@ from mindspore import load_checkpoint from mindformers.tools.logger import logger from mindformers.tools.resume_ckpt import get_resume_checkpoint_by_meta -from mindformers.tools.utils import get_real_rank, is_publicly_accessible_path, replace_rank_id_in_ckpt_name +from mindformers.tools.utils import get_real_rank, replace_rank_id_in_ckpt_name +from mindformers.utils.file_utils import is_publicly_accessible_path from mindformers.trainer.utils import get_last_checkpoint, is_hyper_param_existed_in_sf_dir diff --git a/run_mindformer.py b/run_mindformer.py index b28ca84e2c57ca131df6256ee042aa37046af364..9183d1dbca8aff91253d389f4b7595bd70212e5b 100644 --- a/run_mindformer.py +++ b/run_mindformer.py @@ -22,7 +22,7 @@ from mindformers.tools.utils import str2bool, parse_value, str2bool_or_str from mindformers.core.context import build_context from mindformers.trainer import Trainer from mindformers.tools.logger import logger -from mindformers.tools import set_output_path +from mindformers.utils.file_utils import set_output_path SUPPORT_MULTI_MODAL_FILETYPES = { "video": (".mp4", ".avi", ".mkv"), diff --git a/toolkit/benchmarks/eval_with_harness.py b/toolkit/benchmarks/eval_with_harness.py index a4da4805a861e9cb5685fe529915860cceae70fa..24cf4dadcab02cc25ab0890156a26dae364d1db7 100644 --- a/toolkit/benchmarks/eval_with_harness.py +++ b/toolkit/benchmarks/eval_with_harness.py @@ -43,7 +43,7 @@ from mindformers import ( AutoTokenizer ) from mindformers.trainer.utils import transform_and_load_checkpoint -from mindformers.tools import set_output_path +from mindformers.utils.file_utils import set_output_path from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert eval_logger = utils.eval_logger @@ -413,8 +413,8 @@ class MFLM(TemplateLM): if not continuation_enc: raise ValueError("continuation_enc must not be None") if len(continuation_enc) > self.max_length: - raise ValueError("The length of continuation_enc must be less than \ - or equal to max_length, but got {}".format(len(continuation_enc))) + raise ValueError("The length of continuation_enc must be less than " + f"or equal to max_length, but got {len(continuation_enc)}") # how this all works (illustrated on a causal decoder-only setup): # CTX CONT @@ -456,7 +456,7 @@ class MFLM(TemplateLM): # (discard context toks if decoder-only ; discard right-padding) # also discards + checks for "virtual tokens" in the causal LM's input window # from prompt/prefix tuning tokens, if applicable - ctx_len = (inplen + (logits.shape[0] - padding_len_inp)) + ctx_len = inplen + (logits.shape[0] - padding_len_inp) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = logits.unsqueeze(0) # [1, seq, vocab]