From 1ba6a3c0ab75292967d7869be30df3b29361caf2 Mon Sep 17 00:00:00 2001 From: senzhen Date: Fri, 14 Nov 2025 09:39:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dsafetensor=E6=9D=83=E9=87=8D?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E7=BB=AD=E8=AE=AD=E8=8B=A5=E5=B9=B2=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mindformers/core/callback/callback.py | 148 ++++++------- mindformers/core/context/parallel.py | 2 +- .../dataloader/blended_megatron_dataloader.py | 4 +- .../dataloader/multi_source_dataloader.py | 104 +++++---- mindformers/models/modeling_utils.py | 17 +- mindformers/tools/__init__.py | 2 - .../ckpt_transform/transform_checkpoint.py | 119 ++++++---- mindformers/tools/resume_ckpt.py | 35 +-- mindformers/tools/utils.py | 208 ++---------------- mindformers/trainer/trainer.py | 53 +++-- mindformers/trainer/utils.py | 14 +- mindformers/utils/__init__.py | 12 +- mindformers/utils/file_utils.py | 199 +++++++++++++++++ mindformers/utils/load_checkpoint_utils.py | 82 +++---- mindformers/utils/process_utils.py | 30 +++ mindformers/utils/resume_ckpt_utils.py | 3 +- run_mindformer.py | 6 +- tests/st/test_safetensors/test_qkv_check.py | 4 +- toolkit/benchmarks/eval_with_harness.py | 16 +- 19 files changed, 592 insertions(+), 466 deletions(-) create mode 100644 mindformers/utils/file_utils.py create mode 100644 mindformers/utils/process_utils.py diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index 8ed839c04..effec93ba 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -61,6 +61,7 @@ from mindspore.profiler import ProfilerLevel, schedule from mindformers.tools import get_output_root_path from mindformers.tools.logger import logger from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.utils.process_utils import barrier_world from mindformers.tools.utils import ( get_output_subpath, get_real_rank, @@ -68,7 +69,6 @@ from mindformers.tools.utils import ( get_real_local_rank, get_pipeline_rank_ids, is_last_pipeline_stage, - barrier_world, get_ascend_log_path, set_safe_mode_for_file_or_dir ) @@ -80,6 +80,7 @@ from mindformers.parallel_core.training_graph.loss_func import ( check_device_local_loss ) +# pylint: disable=import-outside-toplevel __all__ = ['MFLossMonitor', 'CheckpointMonitor', 'SummaryMonitor', 'ProfileMonitor', 'EvalCallBack'] _cur_dir = os.getcwd() @@ -94,7 +95,7 @@ class AllReduceNet(Cell): """ def __init__(self, group_name): - super(AllReduceNet, self).__init__() + super().__init__() self.allreduce_sum = P.AllReduce(op=P.ReduceOp.SUM, group=group_name) self.add_flags(skip_auto_parallel_compile=True) @@ -206,7 +207,7 @@ class MFLossMonitor(Callback): gradient_accumulation_steps: int = 1, check_for_nan_in_loss_and_grad: bool = False, calculate_per_token_loss: bool = False): - super(MFLossMonitor, self).__init__() + super().__init__() self.per_print_times = per_print_times self.learning_rate = deepcopy(learning_rate) self.last_print_time = 0 @@ -476,7 +477,7 @@ class MFLossMonitor(Callback): global_step = cur_step_num + (cur_epoch_num - 1) * steps_per_epoch if self.mf_calculated: throughput_per_npu = self.full_model_flops / per_step_seconds / 1e9 - throughput_info = ', train_throughput_per_npu: %.3fT' % (throughput_per_npu) + throughput_info = f', train_throughput_per_npu: {throughput_per_npu:.3f}T' if self.tensor_writer is not None: self.tensor_writer.add_scalar('model-flops-throughput-per-npu', float(throughput_per_npu), @@ -484,33 +485,26 @@ class MFLossMonitor(Callback): else: throughput_info = '' + loss_str = f"{loss:5.3f}" if cb_params.dataset_sink_mode else \ + f"[{loss:5.3f}/{np.mean(self.loss_list):5.3f}]" if current_lr is not None: - if cb_params.dataset_sink_mode: - logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], loss: %5.3f, " - "per_step_time: %dms, lr: %s, overflow cond: %s, loss_scale: %s, global_norm: %s%s", - cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch, loss, - int(per_step_seconds), current_lr, overflow, scaling_sens, global_norm, throughput_info) - else: - logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], loss:[%5.3f/%5.3f], " - "per_step_time: %dms, lr: %s, overflow cond: %s, loss_scale: %s, global_norm: %s%s", - cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch, loss, np.mean(self.loss_list), - int(per_step_seconds), current_lr, overflow, scaling_sens, global_norm, throughput_info) + logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], loss: %s, " + "per_step_time: %dms, lr: %s, overflow cond: %s, loss_scale: %s, global_norm: %s%s", + cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch, loss_str, + int(per_step_seconds), current_lr, overflow, scaling_sens, global_norm, throughput_info) if self.tensor_writer is not None: self.tensor_writer.add_scalar('learning-rate', float(current_lr), global_step=global_step) self.tensor_writer.add_scalar('learning-rate vs samples', float(current_lr), global_step=global_step * self.global_batch_size) else: - if cb_params.dataset_sink_mode: - logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], loss: %5.3f, " - "per_step_time: %dms, overflow cond: %s, loss_scale: %s, global_norm: %s%s", - cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch, loss, - int(per_step_seconds), overflow, scaling_sens, global_norm, throughput_info) - else: - logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], loss:[%5.3f/%5.3f], " - "per_step_time: %dms, overflow cond: %s, loss_scale: %s, global_norm: %s%s", - cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch, loss, np.mean(self.loss_list), - int(per_step_seconds), overflow, scaling_sens, global_norm, throughput_info) - show_str = ('|%%-%ds|' % 50) % (int(50 * percent / 100) * "█") + logger.info("{ Epoch:[%3d/%3d], step:[%5d/%5d], loss: %s, " + "per_step_time: %dms, overflow cond: %s, loss_scale: %s, global_norm: %s%s", + cur_epoch_num, origin_epochs, cur_step_num, steps_per_epoch, loss_str, + int(per_step_seconds), overflow, scaling_sens, global_norm, throughput_info) + + # print progress bar + bar = int(50 * percent / 100) * "█" + show_str = f"|{bar:<50}|" logger.info(" %4.1f%% %s %.5f samples/s/p %s }", percent, show_str, throughput, timedelta(seconds=int(time_remain))) if self.tensor_writer is not None: @@ -620,7 +614,7 @@ class TrainingStateMonitor(Callback): use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False): - super(TrainingStateMonitor, self).__init__() + super().__init__() if not (isinstance(step_interval, int) and step_interval > 0): logger.warning(f"The value of 'monitor_config.step_interval' should be positive integer, " f"but get {step_interval}. Use default value: 1.") @@ -777,15 +771,15 @@ class TrainingStateMonitor(Callback): parent_dirs = os.path.dirname(self.global_norm_record_path) if not os.path.exists(parent_dirs): os.makedirs(parent_dirs) - with open(self.global_norm_record_path, 'w') as file: + with open(self.global_norm_record_path, 'w', encoding="utf-8") as file: json.dump(self.abnormal_global_norms, file) set_safe_mode_for_file_or_dir(self.global_norm_record_path) logger.info(f"Current global norm {global_norm} is greater equal than " f"threshold {self.global_norm_spike_threshold}, stop training...") barrier_world() - logger.info(f"Call barrier before throw TREError.") + logger.info("Call barrier before throw TREError.") ms.runtime.synchronize() - logger.info(f"All stream execution completed.") + logger.info("All stream execution completed.") raise RuntimeError("TREError occurred......") self.abnormal_global_norms[str(global_step)].append(global_norm.item()) logger.info(f"The global norm {global_norm} of step {global_step} is still greater or equal " @@ -843,15 +837,15 @@ class TrainingStateMonitor(Callback): if self.print_struct is None: self.print_struct = False - if not (isinstance(self.target, list) and self.target and all([isinstance(i, str) for i in self.target])): - raise TypeError(f"The value of 'target' should be a list of str.") + if not (isinstance(self.target, list) and self.target and all(isinstance(i, str) for i in self.target)): + raise TypeError("The value of 'target' should be a list of str.") if not isinstance(self.invert, bool): - raise TypeError(f"The value of 'invert' should be bool.") + raise TypeError("The value of 'invert' should be bool.") if (self.throughput_baseline is not None and not (isinstance(self.throughput_baseline, (int, float)) and self.throughput_baseline > 0)): - raise ValueError(f"The value of 'throughput_baseline' should be None or positive number.") + raise ValueError("The value of 'throughput_baseline' should be None or positive number.") if not isinstance(self.print_struct, bool): - raise TypeError(f"The value of 'print_struct' should be bool.") + raise TypeError("The value of 'print_struct' should be bool.") attrs = ['local_norm_format', 'local_loss_format', 'device_local_norm_format', 'device_local_loss_format', 'optimizer_state_format', 'weight_state_format'] for attr in attrs: @@ -859,7 +853,7 @@ class TrainingStateMonitor(Callback): if self.global_norm_record_path and os.path.exists(self.global_norm_record_path): # the data format might be like {"300": [3.3], "600": [4.1, 4.2],} # because json cannot use number as key, we convert it to string - with open(self.global_norm_record_path, 'r') as file: + with open(self.global_norm_record_path, 'r', encoding="utf-8") as file: self.abnormal_global_norms = json.load(file) def _check_attr_formats(self, attr): @@ -940,7 +934,7 @@ class TrainingStateMonitor(Callback): continue data = np.load(os.path.join(self.dump_path, f), allow_pickle=False) if prefix == 'device_local_norm': - self._output(f'device_local_norm', data, self.dump_step, self.device_local_norm_format) + self._output('device_local_norm', data, self.dump_step, self.device_local_norm_format) elif prefix == 'local_loss': # collect all local loss if there are more than one local loss within one step local_losses[suffix] = local_losses.get(suffix, []) @@ -1208,7 +1202,7 @@ class CheckpointMonitor(ModelCheckpoint): self.embedding_size = embedding_size self.health_ckpts_record_dir = health_ckpts_record_dir - prefix = prefix + "_rank_{}".format(self.rank_id) + prefix = prefix + f"_rank_{self.rank_id}" self.global_batch_size = global_batch_size # this list records parameters which will be ignored when saving ckpt. @@ -1253,7 +1247,7 @@ class CheckpointMonitor(ModelCheckpoint): format=checkpoint_format, exception_save=exception_save, remove_redundancy=remove_redundancy) - super(CheckpointMonitor, self).__init__(prefix, ckpt_directory, config=config_ck) + super().__init__(prefix, ckpt_directory, config=config_ck) self.meta_json = os.path.join(self._directory, "meta.json") if self._config.async_save: self.last_epoch_num = None @@ -1305,8 +1299,8 @@ class CheckpointMonitor(ModelCheckpoint): keys = list(self.save_info_list.keys()) for record_step in keys: self.print_savetime(record_step, cb_params.batch_num) - if not any([self.save_info_list[record_step][key]['ckpt_file_path'] - for key in ['ckpt', 'network', 'trainable_params']]): + if not any(self.save_info_list[record_step][key]['ckpt_file_path'] + for key in ['ckpt', 'network', 'trainable_params']): self.save_info_list.pop(record_step) if self._config.async_save and not ms.async_ckpt_thread_status() and \ @@ -1395,11 +1389,11 @@ class CheckpointMonitor(ModelCheckpoint): } all_step_health_data = [] if os.path.exists(dump_health_json_path): - with open(dump_health_json_path, 'r') as file: + with open(dump_health_json_path, 'r', encoding="utf-8") as file: data = json.load(file) all_step_health_data = list(data) all_step_health_data.append(health_step_data) - with open(dump_health_json_path, 'w') as file: + with open(dump_health_json_path, 'w', encoding="utf-8") as file: json.dump(all_step_health_data, file, indent=4) set_safe_mode_for_file_or_dir(dump_health_json_path) @@ -1491,8 +1485,7 @@ class CheckpointMonitor(ModelCheckpoint): for rank in cur_dp: save_param_names = single_params.get(rank) if save_param_names == param_layout.keys(): - logger.warning( - f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.") + logger.warning("For remove_redundancy save checkpoint, the saved parameters are non-redundant.") param_layout_set = set(param_layout.keys()) if parallel_mode else set() cur_file = re.sub(r'rank_\d+', f'rank_{rank}', cur_file) self._tft_save_ckpt(param_layout_set, save_param_names, cur_file, append_dict, network) @@ -1540,8 +1533,7 @@ class CheckpointMonitor(ModelCheckpoint): save_param_names = single_params.get(rank_id) param_layout_set = set(param_layout.keys()) if save_param_names == param_layout.keys(): - logger.warning( - f"For remove_redundancy save checkpoint, the saved parameters are non-redundant.") + logger.warning("For remove_redundancy save checkpoint, the saved parameters are non-redundant.") def choice_func(x): return (x not in param_layout_set or (save_param_names is not None and x in save_param_names)) \ @@ -1697,7 +1689,7 @@ class ProfileMonitor(Callback): start_profile=True, profile_rank_ids=None, profile_pipeline=False, profile_communication=False, profile_memory=False, config=None, profiler_level=0, with_stack=False, data_simplification=True, mstx=False, **kwargs): - super(ProfileMonitor, self).__init__() + super().__init__() self.mstx_range_id = None self.mstx_enabled = is_version_ge(ms.__version__, '2.5.0') and not _check_mspti_is_on() self.stop_step = stop_step @@ -1723,14 +1715,14 @@ class ProfileMonitor(Callback): rank_id = get_real_rank() self.pipeline_rank_ids = get_pipeline_rank_ids() if self.profile_pipeline else None if self.pipeline_rank_ids == [-1]: - raise ValueError(f"Device num should be divided by pipeline stage num.") + raise ValueError("Device num should be divided by pipeline stage num.") if self._is_profile_required(rank_id): if not output_path: output_path = get_output_subpath('profile', rank_id) else: - output_path = os.path.join(output_path, 'profile', 'rank_{}'.format(rank_id)) - logger.info("Profile save path: %s", output_path) + output_path = os.path.join(output_path, 'profile', f'rank_{rank_id}') + logger.info(f"Profile save path: {output_path}") if ms.get_context("device_target") == "GPU" and profile_memory: logger.warning("The parameter profile_memory is not supported on GPU currently, " @@ -2005,7 +1997,7 @@ class ColdHotExpertMonitor(Callback): self.local_expert_num = self.expert_num // self.ep start_index = (self.rank_id // self.mp) * self.local_expert_num end_index = start_index + self.local_expert_num - self.local_expert_index = [i for i in range(start_index, end_index)] + self.local_expert_index = list(range(start_index, end_index)) self.rank_size = int(os.getenv("RANK_SIZE")) def on_train_step_end(self, run_context): @@ -2378,14 +2370,14 @@ class MoEDropRateCallback(Callback): if fi.sum() > 0: delta = fi - self.capacity_factor_over_expert_num droprate = ms.ops.sum(delta * (delta > 0)) - logger.info("layer: %d, drop_rate: %.5f" % (i, droprate)) + logger.info(f"layer: {i}, drop_rate: {droprate:.5f}") else: if hasattr(network.model.layers[i].feed_forward, "router"): fi = network.model.layers[i].feed_forward.router.router.fi_parameter.value() if fi.sum() > 0: delta = fi - self.capacity_factor_over_expert_num droprate = ms.ops.sum(delta * (delta > 0)) - logger.info("layer: %d, drop_rate: %.5f" % (i, droprate)) + logger.info(f"layer: {i}, drop_rate: {droprate:.5f}") def on_train_step_end(self, run_context): cb_params = run_context.original_args() @@ -2436,7 +2428,7 @@ class StressTestModelMonitor(Callback): check_stresslog_interval_time=60): logger.warning('StressTestModelMonitor serves as an experimental interface and its functionality is ' 'not yet stable.') - super(StressTestModelMonitor, self).__init__() + super().__init__() self.interval_steps = interval_steps self.last_checked_step = 0 @@ -2462,8 +2454,8 @@ class StressTestModelMonitor(Callback): if not isinstance(self.compare_interval_steps, int) or self.compare_interval_steps < 1: logger.warning(f"For StressTestMonitor, compare_interval_steps must be an integer greater than or equal" f" to 1, but got {self.compare_interval_steps}.") - logger.warning(f"Skipping interval steps comparison, only the last step result will be compared." - f" compare_interval_steps is set to None") + logger.warning("Skipping interval steps comparison, only the last step result will be compared." + " compare_interval_steps is set to None") self.compare_interval_steps = None self.stress_test_log_dir = stress_test_log_dir self.check_stresslog_interval_time = check_stresslog_interval_time @@ -2510,23 +2502,22 @@ class StressTestModelMonitor(Callback): log_file_path = os.path.join(saved_dir, "worker_0.log") # Start the subprocess command = shlex.split(command) - result_1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - - # Monitor the log file - while result_1.poll() is None: # While the subprocess is running - time.sleep(self.check_stresslog_interval_time) - log_msg = self.readlog(log_file_path) - logger.info(f"Checking stress test log every {self.check_stresslog_interval_time} seconds") - logger.info(f"Current state of stress_test: {log_msg}") - - # Once the subprocess has finished, check the result - if result_1.returncode != 0: - logger.warning(f"An error occurred while running the stress test model on rank {rank_id}: \ - {result_1.stderr.read().decode('utf-8')}") - logger.warning(f"Check the sub task workers log for rank {rank_id} for more details.") + with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as result_1: + # Monitor the log file + while result_1.poll() is None: # While the subprocess is running + time.sleep(self.check_stresslog_interval_time) + log_msg = self.readlog(log_file_path) + logger.info(f"Checking stress test log every {self.check_stresslog_interval_time} seconds") + logger.info(f"Current state of stress_test: {log_msg}") + + # Once the subprocess has finished, check the result + if result_1.returncode != 0: + logger.warning(f"An error occurred while running the stress test model on rank {rank_id}: \ + {result_1.stderr.read().decode('utf-8')}") + logger.warning(f"Check the sub task workers log for rank {rank_id} for more details.") barrier() - logger.info(f"Stress tests ended, now starting to collect and compare results") + logger.info("Stress tests ended, now starting to collect and compare results") # If compare_interval_steps is None, only compare the last step result, and check for its validity. if not self.compare_interval_steps: @@ -2534,6 +2525,7 @@ class StressTestModelMonitor(Callback): "so only the last step result is compared.") else: interval_results = None + subtask_global_step_num = None logger.info(f"Test results are compared every {self.compare_interval_steps} steps") for i in range(self.worker_num): if get_rank() % self.worker_num == i: @@ -2568,7 +2560,7 @@ class StressTestModelMonitor(Callback): gathered_results, _ = all_gather_into_tensor(last_step_results) # gathered_results = gathered_results.asnumpy() # - logger.debug(f"Collected last step results are gathered_results.") + logger.debug("Collected last step results are gathered_results.") logger.info("Last step results are collected from each rank, now starting to compare last step results") rank0_result = gathered_results[0] @@ -2589,7 +2581,7 @@ class StressTestModelMonitor(Callback): """Extract the last step's results from the log file.""" loss_value = None global_norm_value = None - with open(log_file, 'r') as file: + with open(log_file, 'r', encoding="utf-8") as file: lines = file.readlines() for line in reversed(lines): if "INFO - {" in line: @@ -2605,7 +2597,7 @@ class StressTestModelMonitor(Callback): last_recorded_step = 0 results = [] steps_per_epoch = None - with open(log_file, 'r') as file: + with open(log_file, 'r', encoding="utf-8") as file: lines = file.readlines() for line in lines: if "INFO - {" in line: @@ -2675,7 +2667,7 @@ class StressTestModelMonitor(Callback): value_pairs = [val[1] for val in disc_values] logger.warning(f"STRESS TEST FAILED. DISCREPANCIES found in epoch {epoch_step[0]}, " f"step {epoch_step[1]}: ranks {indices}, (loss, global_norm) = {value_pairs}") - logger.warning(f"Check the workers log of the problematic rank for detailed results") + logger.warning("Check the workers log of the problematic rank for detailed results") return False def get_value_from_line(self, line, pattern): @@ -2689,7 +2681,7 @@ class StressTestModelMonitor(Callback): """ Search for the latest line indicating training has started, based on key identifiers. """ - with open(file_path, 'r', errors='ignore') as f: + with open(file_path, 'r', errors='ignore', encoding="utf-8") as f: lines = f.readlines() # Define the keywords indicating training has started @@ -2728,7 +2720,7 @@ class SDCMonitor(Callback): strike_num: int = 3, checksum_time: int = 5, checksum_cooldown_time: int = 180): - super(SDCMonitor, self).__init__() + super().__init__() logger.warning('SDCMonitor serves as an experimental interface and its functionality is not yet stable.') npu_asd_enable = int(os.getenv('NPU_ASD_ENABLE', '0')) @@ -2780,7 +2772,7 @@ class SDCMonitor(Callback): file_path = os.path.join(self.device_log_path, file) if not os.path.exists(file_path): continue - with open(file_path, 'r') as f: + with open(file_path, 'r', encoding="utf-8") as f: logs = f.read() error_logs = self.silent_check_error_pattern.findall(logs) for log in error_logs: diff --git a/mindformers/core/context/parallel.py b/mindformers/core/context/parallel.py index 3496a43b2..d03f77c0d 100644 --- a/mindformers/core/context/parallel.py +++ b/mindformers/core/context/parallel.py @@ -23,7 +23,7 @@ from mindformers.modules.transformer.transformer import ( TransformerOpParallelConfig, ) from mindformers.tools.logger import logger -from mindformers.tools.utils import set_strategy_save_path +from mindformers.utils.file_utils import set_strategy_save_path from mindformers.trainer.config_args import ParallelConfig diff --git a/mindformers/dataset/dataloader/blended_megatron_dataloader.py b/mindformers/dataset/dataloader/blended_megatron_dataloader.py index 52502bb98..a41db2e20 100644 --- a/mindformers/dataset/dataloader/blended_megatron_dataloader.py +++ b/mindformers/dataset/dataloader/blended_megatron_dataloader.py @@ -33,9 +33,9 @@ from mindformers.tools.utils import ( get_dp_from_dataset_strategy, get_real_group_size, get_real_rank, - get_real_local_rank, - is_publicly_accessible_path + get_real_local_rank ) +from mindformers.utils.file_utils import is_publicly_accessible_path def is_dataset_built_on_rank() -> bool: diff --git a/mindformers/dataset/dataloader/multi_source_dataloader.py b/mindformers/dataset/dataloader/multi_source_dataloader.py index f12bc614a..24c055c23 100644 --- a/mindformers/dataset/dataloader/multi_source_dataloader.py +++ b/mindformers/dataset/dataloader/multi_source_dataloader.py @@ -24,10 +24,11 @@ from tqdm import tqdm from mindspore.dataset import Dataset, GeneratorDataset, Shuffle +from mindformers.utils.file_utils import is_publicly_accessible_path from .build_dataloader import build_dataset_loader from ...tools.logger import logger from ...tools.register import MindFormerRegister, MindFormerModuleType -from ...tools.utils import get_real_rank, is_publicly_accessible_path, get_device_num_per_node +from ...tools.utils import get_real_rank, get_device_num_per_node @MindFormerRegister.register(MindFormerModuleType.DATASET_LOADER) @@ -49,21 +50,7 @@ class MultiSourceDataLoader: shard_id = kwargs.pop("shard_id", None) if dataset_ratios is not None: - if any([ratios < 0 for ratios in dataset_ratios]): - raise ValueError( - f"the dataset_ratios should be a list of positive value, but got {dataset_ratios}") - - if abs(sum(dataset_ratios) - 1) > 1e-5: - raise ValueError("the sum of ratios is not equals to 1") - - if samples_count is None or samples_count <= 0: - raise ValueError(f"the samples_count should be a positive int when dataset_ratios is not None, " - f"but got {samples_count}") - - if not isinstance(dataset_ratios, list) or len(dataset_ratios) != len(sub_data_loader): - raise ValueError( - "the dataset_ratios should be a list and the length should equal to sub_data_loader") - + validate_dataset_ratios(dataset_ratios, sub_data_loader, samples_count) nums_per_dataset = [int(ratio * samples_count) for ratio in dataset_ratios] need_sample = True else: @@ -71,35 +58,13 @@ class MultiSourceDataLoader: nums_per_dataset = [] need_sample = False else: - if not isinstance(nums_per_dataset, list) or len(nums_per_dataset) != len(sub_data_loader): - raise ValueError( - "the nums_per_dataset should be a list and the length should equal to sub_data_loader") - - if any([num < 0 for num in nums_per_dataset]): - raise ValueError( - f"the nums_per_dataset should be a list of positive value, but got {nums_per_dataset}") - + validate_nums_per_dataset(nums_per_dataset, sub_data_loader) need_sample = True if sub_data_loader_args is None: - sub_data_loader_args = dict() + sub_data_loader_args = {} - if not isinstance(shuffle, bool) and shuffle.lower() not in ["global", "files", "infile"]: - raise ValueError( - f"shuffle should be a bool or a str and the value must be one of ['global', 'files', 'infile']") - - if isinstance(shuffle, bool): - shuffle_dataset = shuffle - shuffle_file = shuffle - elif shuffle == Shuffle.INFILE: - shuffle_dataset = False - shuffle_file = True - elif shuffle == Shuffle.FILES: - shuffle_dataset = True - shuffle_file = False - else: - shuffle_dataset = True - shuffle_file = True + shuffle_dataset, shuffle_file = get_shuffle_flags(shuffle) sub_data_loader_args["shuffle"] = shuffle_file dataset_loaders = [] @@ -175,6 +140,57 @@ class MultiSourceDataLoader: return dataset +def validate_dataset_ratios(dataset_ratios: list[float], sub_data_loader: list, samples_count: Optional[int]) -> None: + """Validates the validity of dataset ratio parameters for data distribution.""" + if any(ratios < 0 for ratios in dataset_ratios): + raise ValueError( + f"the dataset_ratios should be a list of positive value, but got {dataset_ratios}") + + if abs(sum(dataset_ratios) - 1) > 1e-5: + raise ValueError("the sum of ratios is not equals to 1") + + if samples_count is None or samples_count <= 0: + raise ValueError(f"the samples_count should be a positive int when dataset_ratios is not None, " + f"but got {samples_count}") + + if not isinstance(dataset_ratios, list) or len(dataset_ratios) != len(sub_data_loader): + raise ValueError( + "the dataset_ratios should be a list and the length should equal to sub_data_loader") + + +def validate_nums_per_dataset(nums_per_dataset: list[int], sub_data_loader: list) -> None: + """Validates the validity of per-dataset sample count parameters.""" + if not isinstance(nums_per_dataset, list) or len(nums_per_dataset) != len(sub_data_loader): + raise ValueError( + "the nums_per_dataset should be a list and the length should equal to sub_data_loader") + + if any(num < 0 for num in nums_per_dataset): + raise ValueError( + f"the nums_per_dataset should be a list of positive value, but got {nums_per_dataset}") + + +def get_shuffle_flags(shuffle): + """Resolves shuffle configuration into dataset-level and file-level shuffle flags.""" + if not isinstance(shuffle, bool) and shuffle.lower() not in ["global", "files", "infile"]: + raise ValueError( + "shuffle should be a bool or a str and the value must be one of ['global', 'files', 'infile']") + + if isinstance(shuffle, bool): + shuffle_dataset = shuffle + shuffle_file = shuffle + elif shuffle == Shuffle.INFILE: + shuffle_dataset = False + shuffle_file = True + elif shuffle == Shuffle.FILES: + shuffle_dataset = True + shuffle_file = False + else: + shuffle_dataset = True + shuffle_file = True + + return shuffle_dataset, shuffle_file + + def prepare_generator_sub_dataloader_args(class_name, full_args): cls_obj = MindFormerRegister.get_cls(module_type="dataset_loader", class_name=class_name) @@ -352,15 +368,15 @@ class MultiSourceRandomAccessDataset: if os.path.exists(save_indices_npz_path): raise ValueError(f"The save_indices_npz_path {save_indices_npz_path} has existed.") if is_publicly_accessible_path(save_indices_npz_path): - logger.info(f".......... npz file is saved in shared path ..........") + logger.info(".......... npz file is saved in shared path ..........") if self.rank_id == 0: logger.info(f".......... save indices to npz file: {save_indices_npz_path} by rank_{self.rank_id} ." f".........") np.savez_compressed(save_indices_npz_path, dataloader_index=self.dataloader_index, data_sample_index=self.data_sample_index) else: - logger.warning(f"If the npz file is being saved to a shared path, please add this path to the " - f"environment variable SHARED_PATHS, otherwise, please ignore this warning.") + logger.warning("If the npz file is being saved to a shared path, please add this path to the " + "environment variable SHARED_PATHS, otherwise, please ignore this warning.") if self.rank_id % get_device_num_per_node() == 0: logger.info(f".......... save indices to npz file: {save_indices_npz_path} by rank_{self.rank_id} ." f".........") diff --git a/mindformers/models/modeling_utils.py b/mindformers/models/modeling_utils.py index 8e3e306e7..14eb64901 100644 --- a/mindformers/models/modeling_utils.py +++ b/mindformers/models/modeling_utils.py @@ -28,6 +28,8 @@ import mindspore as ms from mindspore import nn from mindspore import load_checkpoint, load_param_into_net from mindspore import context, Model + +import mindformers.models.auto as auto_module from mindformers.tools.check_rules import check_yaml_depth_before_loading from mindformers.tools.hub import ( PushToHubMixin, @@ -46,11 +48,10 @@ from mindformers.tools.ckpt_transform import TransformCkpt, make_soft_link from mindformers.tools.utils import ( get_real_rank, get_output_root_path, - clear_auto_trans_output, - remake_folder, - barrier_world, FILE_PERMISSION ) +from mindformers.utils.file_utils import clear_auto_trans_output, remake_folder +from mindformers.utils.process_utils import barrier_world from mindformers.models.utils import DEFAULT_CHECKPOINT_SAVE_FOLDER from ..mindformer_book import MindFormerBook, print_path_or_list from ..tools.utils import try_sync_file, replace_tk_to_mindpet @@ -513,7 +514,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): if ( filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) - and filename not in shards.keys() + and filename not in shards and is_main_process and reg.fullmatch(filename_no_suffix) is not None ): @@ -589,7 +590,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): meraged_dict = {} if os.path.exists(config_path): - with open(config_path, 'r') as file_reader: + with open(config_path, 'r', encoding='utf-8') as file_reader: check_yaml_depth_before_loading(file_reader) file_reader.seek(0) meraged_dict = yaml.safe_load(file_reader.read()) @@ -688,7 +689,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): config_args = MindFormerConfig(yaml_file) kwargs["checkpoint_name_or_path"] = kwargs.get("checkpoint_name_or_path") \ - if "checkpoint_name_or_path" in kwargs.keys() else ckpt_file + if "checkpoint_name_or_path" in kwargs else ckpt_file config_args.model.model_config.update(**kwargs) logger.info("model config: %s and checkpoint_name_or_path: %s are used for " "model building.", yaml_file, config_args.model.model_config.checkpoint_name_or_path) @@ -731,7 +732,7 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): try_sync_file(yaml_file) config_args = MindFormerConfig(yaml_file) kwargs["checkpoint_name_or_path"] = kwargs.get("checkpoint_name_or_path") \ - if "checkpoint_name_or_path" in kwargs.keys() else pretrained_model_name_or_dir + if "checkpoint_name_or_path" in kwargs else pretrained_model_name_or_dir config_args.model.model_config.update(**kwargs) return config_args @@ -1440,8 +1441,6 @@ class PreTrainedModel(nn.Cell, ModelMixin, GenerationMixin, PushToHubMixin): if not isinstance(auto_class, str): auto_class = auto_class.__name__ - import mindformers.models.auto as auto_module - if not hasattr(auto_module, auto_class): raise ValueError(f"{auto_class} is not a valid auto class.") diff --git a/mindformers/tools/__init__.py b/mindformers/tools/__init__.py index 8f4a2733a..b4d4eefa0 100644 --- a/mindformers/tools/__init__.py +++ b/mindformers/tools/__init__.py @@ -34,8 +34,6 @@ from .utils import ( count_params, get_output_root_path, get_output_subpath, - set_output_path, - set_strategy_save_path, str2bool, calculate_pipeline_stage, is_last_pipeline_stage diff --git a/mindformers/tools/ckpt_transform/transform_checkpoint.py b/mindformers/tools/ckpt_transform/transform_checkpoint.py index 735566178..f63587358 100644 --- a/mindformers/tools/ckpt_transform/transform_checkpoint.py +++ b/mindformers/tools/ckpt_transform/transform_checkpoint.py @@ -30,13 +30,15 @@ from mindformers.tools.utils import ( get_output_root_path, get_remote_save_url, get_device_num_per_node, + is_main_rank, + format_path +) +from mindformers.utils.file_utils import ( create_file, delete_file, - remake_folder, - is_main_rank, - format_path, - barrier_world + remake_folder ) +from mindformers.utils.process_utils import barrier_world from mindformers.tools.logger import logger from mindformers.tools.cloud_adapter import mox_adapter from mindformers.tools.ckpt_transform.utils import ( @@ -51,6 +53,8 @@ from mindformers.tools.ckpt_transform.utils import ( if check_in_modelarts(): import moxing as mox +__all__ = ['TransformCkpt'] + class TransformCkpt: """Transform src_checkpoint from src_strategy to dst_strategy.""" @@ -105,37 +109,9 @@ class TransformCkpt: self.node_num = self.world_size // self.npu_num_per_node if not is_power_of_two(self.npu_num_per_node): raise ValueError( - f"The `npu_num_per_node` must be a power of 2, but get {npu_num_per_node}") + f"The `npu_num_per_node` must be a power of 2, but get {self.npu_num_per_node}") - # Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size. - if transform_process_num < 1: - raise ValueError("transform_process_num should not smaller than 1," - f"but got {transform_process_num}.") - if transform_process_num > self.world_size: - logger.warning("transform_process_num: %d should not bigger than world_size: %d. \ - transform_process_num is set to %d.", - transform_process_num, self.world_size, self.world_size) - transform_process_num = self.world_size - if self.world_size % transform_process_num != 0: - raise ValueError(f"transform_process_num: {transform_process_num} " - f"should be divided by world_size: {self.world_size}.") - if check_in_modelarts() and 1 < transform_process_num < self.node_num: - logger.warning("transform_process_num: %d should not smaller than \ - node_num = world_size // npu_num_per_node = %d when training on AICC. \ - transform_process_num is set to node num = %d", - transform_process_num, self.node_num, self.node_num) - transform_process_num = self.world_size // npu_num_per_node - if check_in_modelarts() and transform_process_num == 1: - # The 0th NPU of each node is responsible for transform all checkpoints. - # For example, if world_size=16 and npu_num_per_node=8, - # then transform_rank_id_list=[0,8]. - self.transform_rank_id_list = [i for i in range(0, self.world_size, self.npu_num_per_node)] - else: - # Obtain transform_rank_id_list. For example, if world_size=8 and transform_process_num=2, - # then transform_rank_id_list=[0,4], means that the 0th rank and the 4th rank - # responsible for transform checkpoints. - self.transform_rank_id_list = \ - [i for i in range(0, self.world_size, self.world_size // transform_process_num)] + self.transform_rank_id_list = self._get_transform_rank_id_list(transform_process_num) self.transform_process_num = len(self.transform_rank_id_list) if auto_trans_ckpt: @@ -376,10 +352,10 @@ class TransformCkpt: raise ValueError(f"The checkpoint of rank_{src_rank_id} is not found!") checkpoint_file_list = sorted(checkpoint_file_list, key=os.path.getmtime) checkpoint_file_map[src_rank_id] = checkpoint_file_list[-1] - save_checkpoint_dir = os.path.join(dst_checkpoint, "rank_{}".format(current_transform_rank_id)) + save_checkpoint_dir = os.path.join(dst_checkpoint, f"rank_{current_transform_rank_id}") os.makedirs(save_checkpoint_dir, exist_ok=True) save_checkpoint_path = os.path.join(save_checkpoint_dir, - "{}.ckpt".format(prefix + str(current_transform_rank_id))) + f"{prefix + str(current_transform_rank_id)}.ckpt") logger.info("rank_list: %s", src_rank_list) logger.info("checkpoint_file_map: %s", checkpoint_file_map) logger.info("save_checkpoint_path: %s", save_checkpoint_path) @@ -451,7 +427,7 @@ class TransformCkpt: if rank_id: merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy_by_rank_{rank_id}.ckpt') else: - merge_path = os.path.join(strategy_path, f'merged_ckpt_strategy.ckpt') + merge_path = os.path.join(strategy_path, 'merged_ckpt_strategy.ckpt') merged_succeed_txt = os.path.join(strategy_path, "merge_succeed.txt") if self.is_main_rank: @@ -515,6 +491,67 @@ class TransformCkpt: return dst_strategy + def _get_transform_rank_id_list(self, transform_process_num): + """ + Generates a list of distributed rank IDs responsible for checkpoint transformation. + + This internal method calculates and returns the list of rank IDs assigned to handle data/checkpoint + transformation in a distributed training environment. It first validates input constraints for + `transform_process_num`, applies environment-specific adjustments (for ModelArts/AICC platforms), + and then computes the rank IDs based on cluster configuration (world size, per-node NPU count). + + The rank ID assignment follows two core strategies: + 1. ModelArts/AICC environment with `transform_process_num=1`: Assigns the 0th NPU of each node + (e.g., world_size=16, npu_num_per_node=8 → rank IDs [0, 8]) + 2. Other scenarios: Distributes rank IDs evenly across the total world size + (e.g., world_size=8, transform_process_num=2 → rank IDs [0, 4]) + + Args: + transform_process_num (int): Number of processes allocated for transformation tasks. + Must be a positive integer that divides `self.world_size` (after potential clamping). + + Returns: + list[int]: Sorted list of rank IDs designated for transformation. + The length of the list equals the final `transform_process_num` (after adjustments). + + Raises: + ValueError: If: + - `transform_process_num` is less than 1 + - `self.world_size` is not divisible by `transform_process_num` (before platform-specific adjustments) + """ + # Before obtaining transform_rank_id_list, check 1 ≤ transform_process_num ≤ world_size. + if transform_process_num < 1: + raise ValueError("transform_process_num should not smaller than 1," + f"but got {transform_process_num}.") + if transform_process_num > self.world_size: + logger.warning(f"transform_process_num: {transform_process_num} should not " + f"bigger than world_size: {self.world_size}. " + f"transform_process_num is set to {self.world_size}.") + transform_process_num = self.world_size + if self.world_size % transform_process_num != 0: + raise ValueError(f"transform_process_num: {transform_process_num} " + f"should be divided by world_size: {self.world_size}.") + + if check_in_modelarts() and 1 < transform_process_num < self.node_num: + logger.warning("transform_process_num: %d should not smaller than \ + node_num = world_size // npu_num_per_node = %d when training on AICC. \ + transform_process_num is set to node num = %d", + transform_process_num, self.node_num, self.node_num) + transform_process_num = self.world_size // self.npu_num_per_node + + if check_in_modelarts() and transform_process_num == 1: + # The 0th NPU of each node is responsible for transform all checkpoints. + # For example, if world_size is 16 and npu_num_per_node is 8, then transform_rank_id_list should be [0,8]. + transform_rank_id_list = list(range(0, self.world_size, self.npu_num_per_node)) + else: + # Obtain transform_rank_id_list. For example, + # if world_size is 8 and transform_process_num is 2, then transform_rank_id_list should be [0,4]. + # which means that the 0th rank and the 4th rank responsible for transform checkpoints. + transform_rank_id_list = list(range(0, self.world_size, self.world_size // transform_process_num)) + + return transform_rank_id_list + + @staticmethod def check_src_checkpoint_and_strategy(src_checkpoint, src_strategy): """check src checkpoint and strategy""" @@ -598,12 +635,12 @@ class TransformCkpt: if check_in_modelarts(): transformed_ckpt_dir_obs = os.path.join(self.transformed_checkpoint_dir_obs, os.path.basename(ckpt_dir)) transform_failed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs, - f'transform_failed_rank_*.txt')) + 'transform_failed_rank_*.txt')) transform_succeed_txts = mox.file.glob(os.path.join(transformed_ckpt_dir_obs, - f'transform_succeed_rank_*.txt')) + 'transform_succeed_rank_*.txt')) else: - transform_failed_txts = glob(os.path.join(ckpt_dir, f'transform_failed_rank_*.txt')) - transform_succeed_txts = glob(os.path.join(ckpt_dir, f'transform_succeed_rank_*.txt')) + transform_failed_txts = glob(os.path.join(ckpt_dir, 'transform_failed_rank_*.txt')) + transform_succeed_txts = glob(os.path.join(ckpt_dir, 'transform_succeed_rank_*.txt')) if transform_failed_txts: raise ValueError(f"Transform failed, find {transform_failed_txts}.") current_count = len(transform_succeed_txts) diff --git a/mindformers/tools/resume_ckpt.py b/mindformers/tools/resume_ckpt.py index 7a577178a..8887e3049 100644 --- a/mindformers/tools/resume_ckpt.py +++ b/mindformers/tools/resume_ckpt.py @@ -18,26 +18,29 @@ import os from mindformers.tools.logger import logger from mindformers.tools.utils import ( - create_file, check_in_modelarts, check_ckpt_file_name, get_real_rank, - get_real_group_size, get_epoch_and_step_from_ckpt_name, get_times_epoch_and_step_from_ckpt_name, get_rank_id_from_ckpt_name, get_remote_save_url, replace_rank_id_in_ckpt_name, - remake_folder, - is_publicly_accessible_path, get_output_root_path, - barrier_world, Validator ) +from mindformers.utils.file_utils import ( + create_file, + remake_folder, + is_publicly_accessible_path +) +from mindformers.utils.process_utils import barrier_world from mindformers.trainer.utils import get_last_checkpoint, is_hyper_param_existed_in_sf_dir if check_in_modelarts(): import moxing as mox +else: + mox = None NO_META = "FOUND NO META.JSON" @@ -97,7 +100,7 @@ def get_resume_checkpoint_by_meta(checkpoint_dir, ckpt_format='ckpt', if last_epoch is None or last_step is None or last_ckpt_file is None: logger.info("No meta.json available and will use the checkpoints " "from the last timestamp for resume training.") - check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format) + check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format) create_file(latest_checkpointed_iteration_txt, NO_META) resume_ckpt = True else: @@ -135,7 +138,7 @@ def checkpoint_health_monitor(health_ckpts_record_dir, resume_ckpt_list): health_ckpts = [] health_monitoring_json = os.path.join(health_ckpts_record_dir, 'health_ckpts.json') if os.path.exists(health_monitoring_json): - with open(health_monitoring_json, "r") as json_file: + with open(health_monitoring_json, "r", encoding="utf-8") as json_file: json_data = json.load(json_file) if not json_data: logger.warning(f"Get nothing from {json_file}.") @@ -165,7 +168,7 @@ def get_resume_ckpt(latest_checkpointed_iteration_txt, rank_id): if not check_in_modelarts(): if not os.path.exists(latest_checkpointed_iteration_txt): raise ValueError(f"Can not find {latest_checkpointed_iteration_txt}") - with open(latest_checkpointed_iteration_txt, 'r') as f: + with open(latest_checkpointed_iteration_txt, 'r', encoding='utf-8') as f: resume_info = [line.strip() for line in f.readlines()] else: if not mox.file.exists(latest_checkpointed_iteration_txt): @@ -199,7 +202,7 @@ def get_info_from_meta(checkpoint_dir, rank_dir_num, ckpt_format='ckpt'): if not os.path.exists(meta_json): logger.warning("%s is not found.", meta_json) continue - with open(meta_json, "r") as json_file: + with open(meta_json, "r", encoding='utf-8') as json_file: try: meta_data = json.load(json_file) if not meta_data: @@ -231,7 +234,7 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num, epoch and step are consistent, and the path exists. """ # get all valid ckpts where the epoch and step values are not greater than those of last_ckpt_file. - ckpt_prefix = last_ckpt_file[:last_ckpt_file.rfind("-")] + ckpt_prefix = last_ckpt_file[:last_ckpt_file.rfind("-") + 1] last_epoch, last_step = get_epoch_and_step_from_ckpt_name(last_ckpt_file, ckpt_format) original_rank = get_rank_id_from_ckpt_name(last_ckpt_file) valid_ckpts = {} @@ -254,9 +257,9 @@ def get_resume_ckpt_list(checkpoint_dir, last_ckpt_file, rank_id, rank_dir_num, # get ckpts suitable for resuming, where their rank numbers are intact, # epoch and step are consistent, and the path exists. resume_ckpt_list = [] - for key in valid_ckpts: - if check_checkpoints_by_rank(valid_ckpts[key], rank_dir_num): - ckpt_file = replace_rank_id_in_ckpt_name(valid_ckpts[key][0], rank_id) + for key, ckpts in valid_ckpts.items(): + if check_checkpoints_by_rank(ckpts, rank_dir_num): + ckpt_file = replace_rank_id_in_ckpt_name(ckpts[0], rank_id) resume_ckpt = os.path.join(checkpoint_dir, f"rank_{rank_id}", ckpt_file) if not os.path.exists(resume_ckpt): raise FileNotFoundError(f"{resume_ckpt} is not found!") @@ -311,14 +314,14 @@ def check_meta_info(epoch, step, ckpt_file, meta_json, ckpt_format='ckpt'): return True -def check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format='ckpt'): +def check_last_timestamp_checkpoints(checkpoint_dir, rank_dir_num, ckpt_format='ckpt'): """ Verify that the prefix, epoch and step of the checkpoints from the last timestamp are equal across all rank folders in the checkpoint_dir directory. """ compared_checkpoint_name = None compared_original_checkpoint_name = None - for rank_id_tmp in range(get_real_group_size()): + for rank_id_tmp in range(rank_dir_num): checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}") last_checkpoint = get_last_checkpoint(checkpoint_rank_dir, ckpt_format) if not last_checkpoint: @@ -334,7 +337,7 @@ def check_last_timestamp_checkpoints(checkpoint_dir, ckpt_format='ckpt'): return find_diff_ckpt = False - for rank_id_tmp in range(get_real_group_size()): + for rank_id_tmp in range(rank_dir_num): checkpoint_rank_dir = os.path.join(checkpoint_dir, f"rank_{rank_id_tmp}") last_checkpoint = get_last_checkpoint(checkpoint_rank_dir) diff --git a/mindformers/tools/utils.py b/mindformers/tools/utils.py index f250d48ab..396cec3af 100644 --- a/mindformers/tools/utils.py +++ b/mindformers/tools/utils.py @@ -16,7 +16,6 @@ import json import os import re -import shutil import stat import tempfile from multiprocessing import Process @@ -33,8 +32,8 @@ except ImportError: import mindspore as ms from mindspore import Tensor, context -from mindspore._checkparam import args_type_check -from mindspore.communication import get_group_size, get_rank, comm_func, get_local_rank +from mindspore.communication import get_group_size, get_rank, get_local_rank + PARALLEL_MODE = {'DATA_PARALLEL': context.ParallelMode.DATA_PARALLEL, 'SEMI_AUTO_PARALLEL': context.ParallelMode.SEMI_AUTO_PARALLEL, @@ -85,7 +84,7 @@ class Validator: def check_type(arg_value, arg_type): """Check int.""" if not isinstance(arg_value, arg_type): - raise TypeError('{} should be {} type, but get {}'.format(arg_value, arg_type, type(arg_value))) + raise TypeError(f'{arg_value} should be {arg_type} type, but get {type(arg_value)}') @staticmethod def is_obs_url(url): @@ -96,11 +95,11 @@ class Validator: def check_obs_url(url): """Check obs url.""" if not isinstance(url, str): - raise TypeError('remote_save_url type should be a str, but get {}, ' - 'please check your remote_save_url config'.format(type(url))) + raise TypeError(f'remote_save_url type should be a str, but get {type(url)}, ' + 'please check your remote_save_url config') if not (url.startswith(_PROTOCOL + '://') or url.startswith(_PROTOCOL_S3 + '://')): raise TypeError('remote_save_url should be start with obs:// or s3://, ' - 'but get {}, please check your remote_save_url config'.format(url)) + f'but get {url}, please check your remote_save_url config') def check_list(var_name: str, list_var: Union[Tuple, List], num: int): @@ -116,7 +115,7 @@ def check_list(var_name: str, list_var: Union[Tuple, List], num: int): """ for value in list_var: if value >= num: - raise ValueError('The index of the {} needs to be less than the number of nodes {}.'.format(var_name, num)) + raise ValueError(f'The index of the {var_name} needs to be less than the number of nodes {num}.') def check_file(file_path, file_type=None): @@ -154,33 +153,6 @@ def get_output_root_path(): return os.path.realpath(expanduser_path) -@args_type_check(path=str) -def set_output_path(path): - """set output path""" - from .logger import logger - if path is None: - path = './output' - expanduser_path = os.path.expanduser(path) - os.environ['LOCAL_DEFAULT_PATH'] = os.path.realpath(expanduser_path) - logger.info(f"set output path to '{os.path.realpath(expanduser_path)}'") - - -def set_strategy_save_path(config): - """set strategy path""" - from .logger import logger - rank_id = get_real_rank() - strategy_ckpt_save_dir = os.path.join(get_output_root_path(), "strategy") - os.makedirs(strategy_ckpt_save_dir, exist_ok=True) - set_safe_mode_for_file_or_dir(strategy_ckpt_save_dir) - - strategy_ckpt_save_file = config.get('strategy_ckpt_save_file', "ckpt_strategy.ckpt") - if not strategy_ckpt_save_file.endswith(f"_rank_{rank_id}.ckpt"): - strategy_name = os.path.basename(strategy_ckpt_save_file).replace(".ckpt", f"_rank_{rank_id}.ckpt") - config['strategy_ckpt_save_file'] = os.path.join(strategy_ckpt_save_dir, strategy_name) - context.set_auto_parallel_context(strategy_ckpt_save_file=config['strategy_ckpt_save_file']) - logger.info(f"set strategy path to '{config['strategy_ckpt_save_file']}'") - - def get_log_path(): path = os.getenv("LOG_MF_PATH", os.path.join(get_output_root_path(), "log")) return os.path.expanduser(path) @@ -192,7 +164,7 @@ def get_output_subpath(sub_class, rank_id=0, append_rank=True): root_path = get_output_root_path() directory = os.path.join(root_path, sub_class) if append_rank: - directory = os.path.join(directory, 'rank_{}'.format(rank_id)) + directory = os.path.join(directory, f'rank_{rank_id}') return format_path(directory) @@ -238,12 +210,13 @@ def get_num_nodes_devices(rank_size: int) -> Tuple[int, int]: num_nodes (int): number of nodes. num_devices (int): number of devices. """ - if rank_size in (2, 4, 8): + device_num_per_node = get_device_num_per_node() + if rank_size <= device_num_per_node: num_nodes = 1 num_devices = rank_size else: - num_nodes = rank_size // 8 - num_devices = 8 + num_nodes = rank_size // device_num_per_node + num_devices = device_num_per_node return num_nodes, num_devices @@ -253,9 +226,9 @@ class Const: def __setattr__(self, key, value): if key in self.__dict__: - raise PermissionError('Can not change const {0}.'.format(key)) + raise PermissionError(f'Can not change const {key}.') if not key.isupper(): - raise ValueError('Const name {0} is not all uppercase.'.format(key)) + raise ValueError(f'Const name {key} is not all uppercase.') self.__dict__[key] = value @@ -330,7 +303,7 @@ def count_params(net): def try_sync_file(file_name): """If the file is still downloading, we need to wait before the file finished downloading""" if fcntl: - with open(file_name, 'r') as fp: + with open(file_name, 'r', encoding='utf-8') as fp: fcntl.flock(fp.fileno(), fcntl.LOCK_EX) @@ -382,16 +355,14 @@ def parse_value(value): b = int(a) except (TypeError, ValueError): return False - else: - return a == b + return a == b def isfloat(x): try: float(x) except (TypeError, ValueError): return False - else: - return True + return True def isbool(x): return x in ["True", "False"] @@ -401,8 +372,7 @@ def parse_value(value): json.loads(x) except json.decoder.JSONDecodeError: return False - else: - return True + return True if isint(value): return int(value) @@ -471,7 +441,7 @@ def get_dp_from_dataset_strategy(): first_input_stra = data_strategy[0] dp = int(first_input_stra[0]) else: - raise TypeError(f"Dataset_strategy in mindspore auto parallel context is invalid, only support (tuple, list)") + raise TypeError("Dataset_strategy in mindspore auto parallel context is invalid, only support (tuple, list)") return dp @@ -505,116 +475,6 @@ def is_last_pipeline_stage(): return (rank // device_num_per_stage + 1) == stage_num -def is_publicly_accessible_path(path): - """Check a path is accessible by all rank.""" - from .logger import logger - if get_real_group_size() <= get_device_num_per_node(): - return True - - if check_in_modelarts(): - return True - - if check_shared_disk(path): - return True - - # For example, SHARED_PATHS="/mnt/shared1,/mnt/shared2" - shared_paths = os.getenv("SHARED_PATHS", "").split(',') - path = os.path.realpath(path) - for shared_path in shared_paths: - if not shared_path: - continue - shared_path = os.path.realpath(shared_path) - if path.startswith(shared_path): - return True - logger.info("System can not identify if given path is shared disk. " - "If it is, Please set env 'SHARED_PATHS' to given path.") - return False - - -def create_file(file_path, info=None): - """create file.""" - if Validator.is_obs_url(file_path): - if not check_in_modelarts(): - raise ValueError(f"When create {file_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - with mox.file.File(file_path, 'w') as f: - if info: - if isinstance(info, list): - for sub_info in info: - f.write(str(sub_info) + "\n") - else: - f.write(info) - else: - f.write("Hugging ModelArts.") - else: - flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC - with os.fdopen(os.open(file_path, flags_, FILE_PERMISSION), 'w') as f: - if info: - if isinstance(info, list): - for sub_info in info: - f.write(str(sub_info) + "\n") - else: - f.write(info) - - -def delete_file(file_path): - """delete file""" - if Validator.is_obs_url(file_path): - if not check_in_modelarts(): - raise ValueError(f"When create {file_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - if mox.file.exists(file_path): - mox.file.remove(file_path, recursive=False) - else: - if os.path.exists(file_path): - os.remove(file_path) - - -def remake_folder(folder_path, permissions=None): - """make folder""" - from .logger import logger - rank_id = get_real_rank() - logger.info("Remake %s...", folder_path) - if Validator.is_obs_url(folder_path): - if not check_in_modelarts(): - raise ValueError(f"When remaking {folder_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - if not rank_id: - if mox.file.exists(folder_path): - mox.file.remove(folder_path, recursive=True) - mox.file.make_dirs(folder_path) - logger.info("OBS: Folder %s is remaked.", folder_path) - else: - if is_main_rank(): - if os.path.exists(folder_path): - shutil.rmtree(folder_path) - os.makedirs(folder_path, exist_ok=True) - os.chmod(folder_path, permissions) - logger.info("Folder %s is remaked.", folder_path) - - -def remove_folder(folder_path, rank_id=None): - """delete folder""" - from .logger import logger - rank_id = rank_id or get_real_rank() - logger.info("Remove %s...", folder_path) - if Validator.is_obs_url(folder_path): - if not check_in_modelarts(): - raise ValueError(f"When removing {folder_path}, \ - it is detected that it is not in the ModelArts platform.") - import moxing as mox - if mox.file.exists(folder_path) and not rank_id: - mox.file.remove(folder_path, recursive=True) - logger.info("OBS: Folder %s is removed.", folder_path) - else: - if os.path.exists(folder_path) and is_main_rank(): - shutil.rmtree(folder_path) - logger.info("Folder %s is removed.", folder_path) - - def set_safe_mode_for_file_or_dir(path): if isinstance(path, str): path = [path] @@ -681,22 +541,6 @@ def replace_rank_id_in_ckpt_name(ckpt_file, dst_rank_id): return ckpt_name -def clear_auto_trans_output(load_checkpoint=None, src_strategy_path_or_dir=None): - """clear transformed_checkpoint and strategy""" - folder_list = ["strategy", "transformed_checkpoint"] - for folder in folder_list: - if check_in_modelarts(): - folder_path = os.path.join(get_remote_save_url(), folder) - else: - folder_path = os.path.join(get_output_root_path(), folder) - if os.path.realpath(folder_path) in (load_checkpoint, src_strategy_path_or_dir): - raise ValueError( - "./transformed_checkpoint or ./strategy with given config.output_dir is same as " - "load_checkpoint or src_strategy_path_or_dir which is not allowed when auto_trans is True." - "Please move it to a different location or specify a different output folder.") - remake_folder(folder_path, permissions=0o750) - - def check_ckpt_file_name(ckpt_file, ckpt_fmt='ckpt'): """Check ckpt name in the format of {prefix}-{epoch}_{step}.ckpt""" ckpt_name = os.path.split(ckpt_file)[1] @@ -731,18 +575,6 @@ def is_pynative(): return enforce_eager.lower() == "true" -def barrier_world(action: str = None): - """barrier all rank until action is done""" - if get_real_group_size() > 1: - from .logger import logger - if action is not None: - logger.info("Wait " + str(action)) - else: - logger.info("Now barriered...") - - comm_func.barrier() - - def get_pipeline_rank_ids(): """Calculate rank id of each stage and return a list of first rank id in each stage. @@ -765,7 +597,7 @@ def get_pipeline_rank_ids(): def ensure_divisibility(numerator, denominator): """Ensure that numerator is divisible by the denominator.""" if numerator % denominator != 0: - raise ValueError("{} is not divisible by {}".format(numerator, denominator)) + raise ValueError(f"{numerator} is not divisible by {denominator}") def divide(numerator, denominator): diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 5f1dc053b..ca73d0ae4 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -42,18 +42,20 @@ from mindformers.models import PreTrainedModel, BaseImageProcessor, \ PreTrainedTokenizerBase, BaseAudioProcessor from mindformers.models.utils import WEIGHTS_NAME from mindformers.tools.utils import ( - set_output_path, - set_strategy_save_path, check_in_modelarts, get_real_rank, get_real_group_size, set_remote_save_url, get_output_root_path, get_device_num_per_node, - is_publicly_accessible_path, - clear_auto_trans_output, FILE_PERMISSION ) +from mindformers.utils.file_utils import ( + set_output_path, + set_strategy_save_path, + is_publicly_accessible_path, + clear_auto_trans_output +) from mindformers.tools.logger import logger from mindformers.tools.register import MindFormerConfig from mindformers.tools.register.config import ordered_yaml_dump @@ -217,7 +219,7 @@ class Trainer: # check_task_and_model if self.task not in SUPPORT_TASKS.keys(): raise ValueError( - "The value of task must be in {}, but get {}".format(SUPPORT_TASKS.keys(), self.task)) + f"The value of task must be in {SUPPORT_TASKS.keys()}, but get {self.task}") if isinstance(self.model, (Cell, PreTrainedModel)): logger.info("The model instance has been entered, " @@ -230,7 +232,7 @@ class Trainer: else: self.is_model_instance = False if isinstance(self.model, str): - self.model = self.model + '_{}'.format(self.pet_method) if self.pet_method else self.model + self.model = f"{self.model}_{self.pet_method}" if self.pet_method else self.model if self.model not in SUPPORT_MODEL_NAMES: raise ValueError(f"model must be in {SUPPORT_MODEL_NAMES} " f"when model's type is string, but got {self.model}.") @@ -295,7 +297,7 @@ class Trainer: if (self.config.auto_trans_ckpt or self.config.resume_training) and \ check_in_modelarts() and self.config.remote_save_url: set_remote_save_url(self.config.remote_save_url) - logger.info(f"Set remote_save_url: %s, the output file will be uploaded to here.", + logger.info("Set remote_save_url: %s, the output file will be uploaded to here.", self.config.remote_save_url) # build trainer @@ -1010,7 +1012,7 @@ class Trainer: logger.warning("The `use_past` is set to False and reinit the incoming model.") model_config.parallel_config = self.config.parallel_config model_config.moe_config = self.config.moe_config - self.model.__init__(model_config) + self.model.__class__.__init__(self.model, model_config) # pylint: disable=unnecessary-dunder-call self.reset_model = False @staticmethod @@ -1038,7 +1040,7 @@ class Trainer: """get last checkpoint for resuming or finetune.""" output_folder = self.config.output_dir checkpoint_dir = os.path.join( - output_folder, DEFAULT_CHECKPOINT_DIR, 'rank_{}'.format(self.rank_id)) + output_folder, DEFAULT_CHECKPOINT_DIR, f'rank_{self.rank_id}') return get_last_checkpoint(checkpoint_dir, self.config.load_ckpt_format) @staticmethod @@ -1202,7 +1204,7 @@ class Trainer: """ Initializes a git repo in `self.config.hub_model_id`. """ - from modelfoundry_hub import create_repo + from modelfoundry_hub import create_repo # pylint: disable=import-outside-toplevel if self.config.rank_id: return @@ -1268,10 +1270,10 @@ class Trainer: config_dict = _reset_config_for_save(config) if config_dir is None: config_dir = os.path.join( - self.configs_directory, model_name.lower() + '_new') + self.configs_directory, f"{model_name.lower()}_new") if not os.path.exists(config_dir): os.makedirs(config_dir, exist_ok=True) - run_yaml_path = os.path.join(config_dir, 'run_{}.yaml'.format(model_name.lower())) + run_yaml_path = os.path.join(config_dir, f"run_{model_name.lower()}.yaml") _save_config_to_yaml(run_yaml_path, config_dict) @@ -1327,7 +1329,7 @@ class Trainer: def _check_config_rules(self): """Check config rules.""" if not self.config.load_checkpoint and self.config.pretrained_model_dir: - from mindformers.utils import contains_safetensors_files + from mindformers.utils import contains_safetensors_files # pylint: disable=import-outside-toplevel if contains_safetensors_files(self.config.pretrained_model_dir): self.config.load_checkpoint = self.config.pretrained_model_dir logger.info(f'Parameter load_checkpoint does not set the weight path default read from ' @@ -1337,12 +1339,13 @@ class Trainer: f'does not contain any safetensors file and load_checkpoint is empty.' f'It will not load any weights.') - if self.config.auto_trans_ckpt and self.config.load_ckpt_format == 'ckpt': + if self.config.auto_trans_ckpt: if not is_publicly_accessible_path(get_output_root_path()): - raise ValueError(f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True," - "the output_dir should be a shared directory that can be accessed by all nodes." - f"but {os.path.abspath(self.config.output_dir)} is not a shared directory.") - clear_auto_trans_output(self.config.load_checkpoint, self.config.src_strategy_path_or_dir) + raise ValueError(f"When device num > {get_device_num_per_node()} and auto_trans_ckpt is set to True, " + f"the output_dir should be a shared directory that can be accessed by all nodes. " + f"But {os.path.abspath(self.config.output_dir)} is not a shared directory.") + clear_auto_trans_output( + self.config.load_checkpoint, self.config.src_strategy_path_or_dir, self.config.load_ckpt_format) if (self.config.auto_trans_ckpt or self.config.resume_training) and not self.config.load_checkpoint: if self.config.model and self.config.model.model_config.checkpoint_name_or_path: @@ -1367,10 +1370,14 @@ class Trainer: def _error_if_checkpoint_prefix_contains_rank_info(self): """Error if checkpoint prefix contains rank info""" - for callback in self.config.callbacks: - if "type" in callback and callback["type"] == "CheckpointMonitor": - if "rank" in callback.get("prefix", "mindformers"): - raise ValueError("The prefix for saving checkpoint is not allowed to contain 'rank'.") + if hasattr(self.config, "callbacks") and self.config.callbacks: + if not isinstance(self.config.callbacks, list): + raise ValueError("Expected 'callbacks' in config to be a list, " + f"but get {type(self.config.callbacks)}.") + for callback in self.config.callbacks: + if "type" in callback and callback["type"] == "CheckpointMonitor": + if "rank" in callback.get("prefix", "mindformers"): + raise ValueError("The prefix for saving checkpoint is not allowed to contain 'rank'.") def _check_args_task_and_model(self): """Check args, task and model.""" @@ -1435,7 +1442,7 @@ class Trainer: The URL of the repository where the model was pushed if `blocking=False`, or a `Future` object tracking the progress of the commit if `blocking=True`. """ - from modelfoundry_hub import upload_folder + from modelfoundry_hub import upload_folder # pylint: disable=import-outside-toplevel if self.hub_model_id is None: self.init_openmind_repo() diff --git a/mindformers/trainer/utils.py b/mindformers/trainer/utils.py index f07218054..7ab9fa492 100644 --- a/mindformers/trainer/utils.py +++ b/mindformers/trainer/utils.py @@ -27,6 +27,7 @@ from mindspore import context, load_checkpoint, load_param_into_net, load_checkp from mindspore import set_seed as ms_set_seed from mindspore import Parameter from mindspore import ops, mint +from mindspore import save_checkpoint from mindformers.tools.logger import logger from mindformers.tools.utils import get_real_rank @@ -270,7 +271,7 @@ def get_distribute_checkpoint_path(checkpoint_dir, rank_id=None, ckpt_format='ck "load_checkpoint should be a checkpoint directory containing the directory of rank_{0-*}," "The directory structure is as follows: **checkpoint_root_dir/rank_{0-*}/**.ckpt") rank_id = rank_id if rank_id is not None else get_real_rank() - distribute_checkpoint_dir = os.path.join(checkpoint_dir, "rank_{}".format(rank_id)) + distribute_checkpoint_dir = os.path.join(checkpoint_dir, f"rank_{rank_id}") distribute_checkpoint_path = get_last_checkpoint(distribute_checkpoint_dir, ckpt_format) logger.info("distribute checkpoint dir: %s", distribute_checkpoint_dir) elif os.path.isfile(checkpoint_dir): @@ -308,7 +309,9 @@ def load_resume_context_from_checkpoint(config, dataset): raise FileNotFoundError(f"The load_checkpoint must be correct, but get {config.load_checkpoint}") if os.path.isdir(config.load_checkpoint): - if config.use_graceful_exit: + # When graceful exit is enabled or auto checkpoint transformation is disabled, + # use real rank ID to locate the checkpoint. + if config.use_graceful_exit or not config.auto_trans_ckpt: rank_id = get_real_rank() else: rank_id = 0 @@ -472,7 +475,7 @@ def check_checkpoint_config_valid(config): def check_rank_folders(path, rank_id): """check if the folders in path are correct""" - folder_name = "rank_{}".format(rank_id) + folder_name = f"rank_{rank_id}" if not os.path.exists(os.path.join(path, folder_name)): return False return True @@ -534,14 +537,14 @@ def load_slora_ckpt(checkpoint_dict, config, network): adapter_path = os.path.join(pet_config.adapter_path, "lora_adapter.json") if not os.path.exists(adapter_path): raise FileNotFoundError(f"The adapter_path must be correct, but get {adapter_path}") - with open(adapter_path, 'r') as file: + with open(adapter_path, 'r', encoding='utf-8') as file: path_dict = json.load(file) adapter_list = [] config_list = [] for adapter_name in network.lora_adapter.adapter_names[1:]: if adapter_name in path_dict.keys(): adapter_model = load_checkpoint(os.path.join(path_dict[adapter_name], "adapter_model.ckpt")) - with open(os.path.join(path_dict[adapter_name], "adapter_config.json"), 'r') as file: + with open(os.path.join(path_dict[adapter_name], "adapter_config.json"), 'r', encoding='utf-8') as file: adapter_config = json.load(file) else: adapter_model = {} @@ -577,7 +580,6 @@ def load_slora_ckpt(checkpoint_dict, config, network): dst_checkpoint_dir = pet_config.adapter_path if config.auto_trans_ckpt: # Save collected lora weights as single ckpt - from mindspore import save_checkpoint src_checkpoint_dir = os.path.join(config.output_dir, "slora_checkpoint") os.makedirs(src_checkpoint_dir, exist_ok=True) src_checkpoint_dir = os.path.join(src_checkpoint_dir, "slora.ckpt") diff --git a/mindformers/utils/__init__.py b/mindformers/utils/__init__.py index 4699a8543..4af43ef06 100644 --- a/mindformers/utils/__init__.py +++ b/mindformers/utils/__init__.py @@ -27,5 +27,15 @@ from .safetensors import ( is_hf_safetensors_dir, check_safetensors_key, ) -from .load_checkpoint_utils import validate_qkv_concat, process_hf_checkpoint from .decorators import deprecated +from .process_utils import barrier_world +from .file_utils import ( + set_output_path, + set_strategy_save_path, + is_publicly_accessible_path, + remake_folder, + remove_folder, + create_file, + delete_file, + clear_auto_trans_output +) diff --git a/mindformers/utils/file_utils.py b/mindformers/utils/file_utils.py new file mode 100644 index 000000000..468dfdec3 --- /dev/null +++ b/mindformers/utils/file_utils.py @@ -0,0 +1,199 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" File utils.""" +import os +import shutil + +from mindspore import context +from mindspore._checkparam import args_type_check + +from mindformers.tools.logger import logger +from mindformers.tools.utils import ( + Validator, + FILE_PERMISSION, + check_in_modelarts, + get_real_rank, + set_safe_mode_for_file_or_dir, + get_output_root_path, + get_real_group_size, + get_device_num_per_node, + get_remote_save_url, + check_shared_disk, + is_main_rank +) + +if check_in_modelarts(): + import moxing as mox +else: + mox = None + + +@args_type_check(path=str) +def set_output_path(path): + """set output path""" + if path is None: + path = './output' + expanduser_path = os.path.expanduser(path) + os.environ['LOCAL_DEFAULT_PATH'] = os.path.realpath(expanduser_path) + logger.info(f"set output path to '{os.path.realpath(expanduser_path)}'") + + +def set_strategy_save_path(config): + """set strategy path""" + rank_id = get_real_rank() + strategy_ckpt_save_dir = os.path.join(get_output_root_path(), "strategy") + os.makedirs(strategy_ckpt_save_dir, exist_ok=True) + set_safe_mode_for_file_or_dir(strategy_ckpt_save_dir) + + strategy_ckpt_save_file = config.get('strategy_ckpt_save_file', "ckpt_strategy.ckpt") + if not strategy_ckpt_save_file.endswith(f"_rank_{rank_id}.ckpt"): + strategy_name = os.path.basename(strategy_ckpt_save_file).replace(".ckpt", f"_rank_{rank_id}.ckpt") + config['strategy_ckpt_save_file'] = os.path.join(strategy_ckpt_save_dir, strategy_name) + context.set_auto_parallel_context(strategy_ckpt_save_file=config['strategy_ckpt_save_file']) + logger.info(f"set strategy path to '{config['strategy_ckpt_save_file']}'") + + +def is_publicly_accessible_path(path): + """Check a path is accessible by all rank.""" + if get_real_group_size() <= get_device_num_per_node(): + return True + + if check_in_modelarts(): + return True + + if check_shared_disk(path): + return True + + # For example, SHARED_PATHS="/mnt/shared1,/mnt/shared2", which will be split by "/mnt/shared1" and "/mnt/shared2". + shared_paths = os.getenv("SHARED_PATHS", "").split(',') + path = os.path.realpath(path) + for shared_path in shared_paths: + if not shared_path: + continue + shared_path = os.path.realpath(shared_path) + if path.startswith(shared_path): + return True + logger.info("System can not identify if given path is shared disk. " + "If it is, Please set env 'SHARED_PATHS' to given path.") + return False + + +def remake_folder(folder_path, permissions=None): + """make folder""" + rank_id = get_real_rank() + logger.info("Remake %s...", folder_path) + if Validator.is_obs_url(folder_path): + if not check_in_modelarts(): + raise ValueError(f"When remaking {folder_path}, \ + it is detected that it is not in the ModelArts platform.") + if not rank_id: + if mox.file.exists(folder_path): + mox.file.remove(folder_path, recursive=True) + mox.file.make_dirs(folder_path) + logger.info("OBS: Folder %s is remaked.", folder_path) + else: + if is_main_rank(): + if os.path.exists(folder_path): + shutil.rmtree(folder_path) + os.makedirs(folder_path, exist_ok=True) + os.chmod(folder_path, permissions) + logger.info("Folder %s is remaked.", folder_path) + + +def remove_folder(folder_path, rank_id=None): + """delete folder""" + rank_id = rank_id or get_real_rank() + logger.info("Remove %s...", folder_path) + if Validator.is_obs_url(folder_path): + if not check_in_modelarts(): + raise ValueError(f"When removing {folder_path}, \ + it is detected that it is not in the ModelArts platform.") + if mox.file.exists(folder_path) and not rank_id: + mox.file.remove(folder_path, recursive=True) + logger.info("OBS: Folder %s is removed.", folder_path) + else: + if os.path.exists(folder_path) and is_main_rank(): + shutil.rmtree(folder_path) + logger.info("Folder %s is removed.", folder_path) + + +def create_file(file_path, info=None): + """create file.""" + if Validator.is_obs_url(file_path): + if not check_in_modelarts(): + raise ValueError(f"When create {file_path}, \ + it is detected that it is not in the ModelArts platform.") + with mox.file.File(file_path, 'w') as f: + if info: + if isinstance(info, list): + for sub_info in info: + f.write(str(sub_info) + "\n") + else: + f.write(info) + else: + f.write("Hugging ModelArts.") + else: + flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + with os.fdopen(os.open(file_path, flags_, FILE_PERMISSION), 'w') as f: + if info: + if isinstance(info, list): + for sub_info in info: + f.write(str(sub_info) + "\n") + else: + f.write(info) + + +def delete_file(file_path): + """delete file""" + if Validator.is_obs_url(file_path): + if not check_in_modelarts(): + raise ValueError(f"When create {file_path}, \ + it is detected that it is not in the ModelArts platform.") + if mox.file.exists(file_path): + mox.file.remove(file_path, recursive=False) + else: + if os.path.exists(file_path): + os.remove(file_path) + + +def clear_auto_trans_output(load_checkpoint=None, src_strategy_path_or_dir=None, load_ckpt_format='ckpt'): + """ + Clear transformed_checkpoint and strategy folders based on the specified checkpoint format. + + Args: + load_checkpoint (str, optional): Path to the checkpoint file/directory to load. + If the target folder path matches this value, a ValueError is raised. Defaults to None. + src_strategy_path_or_dir (str, optional): Path to the source strategy file/directory. + If the target folder path matches this value, a ValueError is raised. Defaults to None. + load_ckpt_format (str, optional): Format of the checkpoint file. Supports 'ckpt' and 'safetensors'. + Determines which transformed checkpoint folder to clear. Defaults to 'ckpt'. + + Raises: + ValueError: If the resolved folder path (strategy or transformed/unified checkpoint) is the same as + `load_checkpoint` or `src_strategy_path_or_dir`. This prevents overwriting critical input data. + """ + folder_list = ["strategy", "transformed_checkpoint"] if load_ckpt_format == 'ckpt' \ + else ["strategy", "unified_checkpoint"] + for folder in folder_list: + if check_in_modelarts(): + folder_path = os.path.join(get_remote_save_url(), folder) + else: + folder_path = os.path.join(get_output_root_path(), folder) + if os.path.realpath(folder_path) in (load_checkpoint, src_strategy_path_or_dir): + raise ValueError( + "./transformed_checkpoint or ./strategy with given config.output_dir is same as " + "load_checkpoint or src_strategy_path_or_dir which is not allowed when auto_trans is True." + "Please move it to a different location or specify a different output folder.") + remake_folder(folder_path, permissions=0o750) diff --git a/mindformers/utils/load_checkpoint_utils.py b/mindformers/utils/load_checkpoint_utils.py index 66ed527b0..a134e4bbd 100644 --- a/mindformers/utils/load_checkpoint_utils.py +++ b/mindformers/utils/load_checkpoint_utils.py @@ -29,14 +29,12 @@ from mindspore.common.api import _pynative_executor from mindspore.communication.comm_func import barrier from mindformers.tools.logger import logger -from mindformers.tools.utils import ( - is_main_rank, - get_real_rank, - clear_auto_trans_output -) +from mindformers.tools.utils import is_main_rank, get_real_rank from mindformers.utils import convert_hf_safetensors_multiprocess, check_safetensors_key, is_hf_safetensors_dir from mindformers.utils.safetensors.convert_safetensors import _convert_index_json from mindformers.version_control import check_safetensors_addition_param_support +from mindformers.models.modeling_utils import PreTrainedModel +from mindformers.parallel_core.inference.utils import generate_state_dict, save_strategy_file from ..version_control import check_tft_valid @@ -172,9 +170,9 @@ def extract_suffix(file_path): match = re.match(pattern, base_name) if not match: - logger.info(f"Filename '{filename}' does not match expected pattern. " - "Skipping suffix extraction.") - return None + logger.warning(f"Filename '{filename}' does not match expected pattern. " + f"Use '{base_name}' as the suffix.") + return base_name # Extract matched groups task_id = match.group(1) # Will be None if no task_id in filename @@ -182,9 +180,9 @@ def extract_suffix(file_path): step = match.group(3) if not epoch or not step: - logger.info(f"Filename '{filename}' is missing epoch or step information. " - "Skipping suffix extraction.") - return None + logger.warning(f"Filename '{filename}' is missing epoch or step information. " + f"Use '{base_name}' as the suffix.") + return base_name # Construct the appropriate suffix based on presence of task_id if task_id: @@ -194,25 +192,26 @@ def extract_suffix(file_path): def _get_src_file_suffix(config): """get file_suffix from config.load_checkpoint.""" - if isinstance(config.resume_training, str): - file_suffix = extract_suffix(config.resume_training) - return config.load_checkpoint, file_suffix - - if os.path.isfile(config.load_checkpoint): - # only support path format: path/rank_x/prefix-{epoch}_{step}.{config.load_ckpt_format} - file_suffix = extract_suffix(config.load_checkpoint) - checkpoint_dir = '/'.join(config.load_checkpoint.split('/')[:-2]) - return checkpoint_dir, file_suffix + checkpoint_dir = None + file_suffix = None + if is_main_rank(): + if isinstance(config.resume_training, str): + checkpoint_dir = config.load_checkpoint + checkpoint_name = config.resume_training.split('.')[0] + f'.{config.load_ckpt_format}' + elif os.path.isfile(config.load_checkpoint): + # only support path format: path/rank_x/prefix-{epoch}_{step}.{config.load_ckpt_format} + checkpoint_dir = '/'.join(config.load_checkpoint.split('/')[:-2]) + checkpoint_name = os.path.basename(config.load_checkpoint) + else: + checkpoint_dir = config.load_checkpoint + rank_path = f"{checkpoint_dir}/rank_0" + last_checkpoint = get_last_checkpoint(rank_path, config.load_ckpt_format) + checkpoint_name = os.path.basename(last_checkpoint) - # config.load_checkpoint is folder - rank_id = get_real_rank() - rank_path = f"{config.load_checkpoint}/rank_{rank_id}" - if not os.path.exists(rank_path): - raise FileNotFoundError(f"{rank_path} not found.") + file_suffix = extract_suffix(checkpoint_name) + logger.info(f"checkpoint name suffix: {file_suffix}") - last_checkpoint = get_last_checkpoint(rank_path, config.load_ckpt_format) - file_suffix = extract_suffix(last_checkpoint) - return config.load_checkpoint, file_suffix + return checkpoint_dir, file_suffix def _get_src_file(checkpoint_dir, checkpoint_name=None, ckpt_format='ckpt'): @@ -273,19 +272,19 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval load_checkpoint_files = [] strategy_path = ms.get_auto_parallel_context('strategy_ckpt_save_file') if ckpt_file_mode == CheckpointFileMode.SINGLE_CHECKPOINT_FILE.value: - logger.info(f"......Use single checkpoint file mode......") + logger.info("......Use single checkpoint file mode......") load_checkpoint_files = [config.load_checkpoint] if ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE.value: - logger.info(f"......Use multi checkpoint file mode......") + logger.info("......Use multi checkpoint file mode......") load_checkpoint_files = glob( os.path.join(load_checkpoint, f"*.{config.load_ckpt_format}")) load_checkpoint_files.sort() config.remove_redundancy = False elif ckpt_file_mode == CheckpointFileMode.MULTI_CHECKPOINT_FILE_WITH_RANK_ID.value: - logger.info(f"......Use multi checkpoint file with rank id mode......") + logger.info("......Use multi checkpoint file with rank id mode......") # change strategy if config.auto_trans_ckpt: - logger.info(f"......auto_trans is True, will unify all rank files and slice to dst parallel strategy......") + logger.info("......auto_trans is True, will unify all rank files and slice to dst parallel strategy......") src_strategy_path = get_merged_src_strategy_path(config) unified_safetensors_path = os.path.join(config.output_dir, 'unified_checkpoint/') load_checkpoint, file_suffix = _get_src_file_suffix(config) @@ -300,7 +299,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval os.path.join(load_checkpoint, f"*.{config.load_ckpt_format}")) load_checkpoint_files.sort() else: - logger.info(f"......auto_trans is False, will not unify or slice rank files......") + logger.info("......auto_trans is False, will not unify or slice rank files......") if check_tft_valid() and not config.remove_redundancy: sf_file_name = load_checkpoint logger.info(f"......tft is enabled and not enable remove_redundancy, sf_file_name={sf_file_name}......") @@ -316,7 +315,7 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval network = model._train_network #build model if config.use_parallel: - logger.info(f"......Start build model in parallel mode......") + logger.info("......Start build model in parallel mode......") build_model(config, model, input_data, do_eval=do_eval, do_predict=do_predict) #wait generate all rank strategy files barrier() @@ -337,9 +336,8 @@ def load_checkpoint_with_safetensors(config, model, network, input_data, do_eval def process_for_stand_alone_mode(config, network, strategy_path): """process for stand alone mode""" - enable_stand_alone = (config.parallel.parallel_mode == 'STAND_ALONE') + enable_stand_alone = config.parallel.parallel_mode == 'STAND_ALONE' if config.use_parallel and enable_stand_alone: - from mindformers.parallel_core.inference.utils import generate_state_dict, save_strategy_file strategy_ckpt_save_dir = os.path.dirname(strategy_path) if is_main_rank(): if os.path.exists(strategy_ckpt_save_dir): @@ -377,9 +375,9 @@ def validate_config_with_file_mode(ckpt_file_mode, use_parallel, auto_trans_ckpt def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_parallel=False, file_suffix=None, remove_redundancy=False): """merge strategy and unified safetensors.""" - logger.info("Start unify safetensors.") if is_main_rank(): # unify checkpoints + logger.info("Start unify safetensors.") logger.info(f"unified safetensors with file_suffix:{file_suffix}, remove_redundancy: {remove_redundancy}") logger.info(f"unified safetensors with save path:{unified_path}") unify_time_start = time.time() @@ -392,7 +390,10 @@ def unify_safetensors(src_checkpoint, src_strategy_path, unified_path, use_paral ) unify_time_end = time.time() logger.info("Time spent unifying safetensors: %.2fs", unify_time_end - unify_time_start) - clear_auto_trans_output() + else: + logger.info("Wait for rank_0 to unify the safetensors, " + "please check the log of rank_0 to get the unify progress.") + if use_parallel: barrier() logger.info("Unified safetensors finished.") @@ -437,7 +438,7 @@ def load_safetensors_checkpoint(config, load_checkpoint_files, network, strategy ms.load_param_into_net(optimizer, hyper_param_dict) else: logger.info("......Start load checkpoint to model......") - params_dict = dict() + params_dict = {} remove_redundancy = config.get('remove_redundancy', False) for checkpoint_file in load_checkpoint_files: logger.info(f"load checkpoint file: {checkpoint_file}") @@ -556,7 +557,6 @@ def validate_qkv_concat(model_cls_or_instance, qkv_concat_config, load_checkpoin Currently only safetensors format is supported. """ # check the type of model_cls_or_instance - from mindformers.models.modeling_utils import PreTrainedModel if not ( isinstance(model_cls_or_instance, PreTrainedModel) or (isinstance(model_cls_or_instance, type) and issubclass(model_cls_or_instance, PreTrainedModel)) @@ -614,7 +614,7 @@ def get_merged_src_strategy_path(config): def get_merged_dst_strategy_path(config, strategy_path): """prepare for dst strategy.""" - enable_stand_alone = (config.parallel.parallel_mode == 'STAND_ALONE') + enable_stand_alone = config.parallel.parallel_mode == 'STAND_ALONE' if config.use_parallel and config.auto_trans_ckpt and not enable_stand_alone: # prepare merged strategy directory merged_strategy = os.path.join(config.output_dir, 'merged_strategy') diff --git a/mindformers/utils/process_utils.py b/mindformers/utils/process_utils.py new file mode 100644 index 000000000..023ad7cd8 --- /dev/null +++ b/mindformers/utils/process_utils.py @@ -0,0 +1,30 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Process utils.""" +from mindspore.communication import comm_func + +from mindformers.tools.logger import logger +from mindformers.tools.utils import get_real_group_size + + +def barrier_world(action: str = None): + """barrier all rank until action is done""" + if get_real_group_size() > 1: + if action is not None: + logger.info("Wait " + str(action)) + else: + logger.info("Now barriered...") + + comm_func.barrier() diff --git a/mindformers/utils/resume_ckpt_utils.py b/mindformers/utils/resume_ckpt_utils.py index 587f65031..82628be9b 100644 --- a/mindformers/utils/resume_ckpt_utils.py +++ b/mindformers/utils/resume_ckpt_utils.py @@ -19,7 +19,8 @@ from mindspore import load_checkpoint from mindformers.tools.logger import logger from mindformers.tools.resume_ckpt import get_resume_checkpoint_by_meta -from mindformers.tools.utils import get_real_rank, is_publicly_accessible_path, replace_rank_id_in_ckpt_name +from mindformers.tools.utils import get_real_rank, replace_rank_id_in_ckpt_name +from mindformers.utils.file_utils import is_publicly_accessible_path from mindformers.trainer.utils import get_last_checkpoint, is_hyper_param_existed_in_sf_dir diff --git a/run_mindformer.py b/run_mindformer.py index 5d4476257..9183d1dbc 100644 --- a/run_mindformer.py +++ b/run_mindformer.py @@ -22,7 +22,7 @@ from mindformers.tools.utils import str2bool, parse_value, str2bool_or_str from mindformers.core.context import build_context from mindformers.trainer import Trainer from mindformers.tools.logger import logger -from mindformers.tools import set_output_path +from mindformers.utils.file_utils import set_output_path SUPPORT_MULTI_MODAL_FILETYPES = { "video": (".mp4", ".avi", ".mkv"), @@ -66,7 +66,7 @@ def main(config): build_context(config) trainer = Trainer(config) - if config.run_mode == 'train' or config.run_mode == 'finetune': + if config.run_mode in ('train', 'finetune'): trainer.train() elif config.run_mode == 'eval': trainer.evaluate(eval_checkpoint=config.load_checkpoint) @@ -217,7 +217,7 @@ if __name__ == "__main__": for item in rest_args_ for i in item.split("=")] if len(rest_args_) % 2 != 0: - raise ValueError(f"input arg key-values are not in pair, please check input args. ") + raise ValueError("input arg key-values are not in pair, please check input args. ") if args_.config is not None and not os.path.isabs(args_.config): args_.config = os.path.join(work_path, args_.config) diff --git a/tests/st/test_safetensors/test_qkv_check.py b/tests/st/test_safetensors/test_qkv_check.py index 503dc8302..10326f1ff 100644 --- a/tests/st/test_safetensors/test_qkv_check.py +++ b/tests/st/test_safetensors/test_qkv_check.py @@ -24,7 +24,7 @@ from safetensors.numpy import save_file from mindformers import LlamaForCausalLM, ChatGLM2Model from mindformers.tools.register import MindFormerConfig -from mindformers.utils import validate_qkv_concat +from mindformers.utils.load_checkpoint_utils import validate_qkv_concat class TestValidateQKVConcat: @@ -151,7 +151,7 @@ class TestValidateQKVConcat: assert "does not support qkv concat check" in log_content def _get_log_content(self): - with open(self.log_file_path, 'r') as log_file: + with open(self.log_file_path, 'r', encoding='utf-8') as log_file: log_content = log_file.read() return log_content diff --git a/toolkit/benchmarks/eval_with_harness.py b/toolkit/benchmarks/eval_with_harness.py index c1baed041..92a667fd8 100644 --- a/toolkit/benchmarks/eval_with_harness.py +++ b/toolkit/benchmarks/eval_with_harness.py @@ -24,6 +24,7 @@ from tqdm import tqdm import mindspore from mindspore import Model, Tensor from mindspore.common import initializer +from mindspore.nn.utils import no_init_parameters from lm_eval import utils from lm_eval.__main__ import cli_evaluate @@ -42,7 +43,7 @@ from mindformers import ( AutoTokenizer ) from mindformers.trainer.utils import transform_and_load_checkpoint -from mindformers.tools import set_output_path +from mindformers.utils.file_utils import set_output_path from mindformers.version_control import check_delay_initialization_support from mindformers.utils.load_checkpoint_utils import get_load_path_after_hf_convert @@ -179,7 +180,7 @@ class MFLM(TemplateLM): self._config.model.model_config.moe_config = self._config.moe_config build_context(self._config) - eval_logger.info(f"Build context finished.") + eval_logger.info("Build context finished.") build_parallel_config(self._config) return self._config @@ -190,12 +191,11 @@ class MFLM(TemplateLM): sig = False if check_delay_initialization_support(): sig = True - from mindspore.nn.utils import no_init_parameters with no_init_parameters(): # Delay initialization self._model = AutoModel.from_config(config) else: self._model = AutoModel.from_config(config) - eval_logger.info(f"Build model finished.") + eval_logger.info("Build model finished.") if not config.load_checkpoint: raise Exception("There is no model ckpt in the model directory.") @@ -213,7 +213,7 @@ class MFLM(TemplateLM): def _create_tokenizer(self, pretrained: str) -> None: """Initialize Tokenizer""" self.tokenizer = AutoTokenizer.from_pretrained(pretrained) - eval_logger.info(f"Build tokenizer finished.") + eval_logger.info("Build tokenizer finished.") def tok_encode( self, string: str, left_truncate_len: Optional[int] = None, add_special_tokens=None @@ -420,8 +420,8 @@ class MFLM(TemplateLM): if not continuation_enc: raise ValueError("continuation_enc must not be None") if len(continuation_enc) > self.max_length: - raise ValueError("The length of continuation_enc must be less than \ - or equal to max_length, but got {}".format(len(continuation_enc))) + raise ValueError("The length of continuation_enc must be less than " + f"or equal to max_length, but got {len(continuation_enc)}") # how this all works (illustrated on a causal decoder-only setup): # CTX CONT @@ -463,7 +463,7 @@ class MFLM(TemplateLM): # (discard context toks if decoder-only ; discard right-padding) # also discards + checks for "virtual tokens" in the causal LM's input window # from prompt/prefix tuning tokens, if applicable - ctx_len = (inplen + (logits.shape[0] - padding_len_inp)) + ctx_len = inplen + (logits.shape[0] - padding_len_inp) logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len) logits = logits.unsqueeze(0) # [1, seq, vocab] -- Gitee