From 8c9e7c915addb99e26e93641ddedbf0235310a13 Mon Sep 17 00:00:00 2001
From: yiyison
Date: Wed, 19 Nov 2025 10:33:06 +0800
Subject: [PATCH] Weights 2.0

---
 mindformers/checkpoint/checkpoint.py          |  94 +++++------
 mindformers/checkpoint/fully_parallel.py      | 156 ++++++++++--------
 mindformers/checkpoint/metadata.py            |  34 +---
 mindformers/checkpoint/sharded_tensor.py      |  50 +++++-
 mindformers/checkpoint/utils.py               |  18 +-
 mindformers/core/callback/callback.py         |  68 ++++----
 mindformers/core/config_args.py               |   2 +-
 mindformers/tools/register/template.py        |   2 +-
 mindformers/trainer/base_trainer.py           |  28 ++--
 mindformers/trainer/trainer.py                |  10 +-
 tests/st/test_ut/base_schema.json             |   2 +-
 .../test_tensorboard/test_tensorboard.py      |   5 +-
 12 files changed, 246 insertions(+), 223 deletions(-)

diff --git a/mindformers/checkpoint/checkpoint.py b/mindformers/checkpoint/checkpoint.py
index 71356b0bd..2a274e966 100644
--- a/mindformers/checkpoint/checkpoint.py
+++ b/mindformers/checkpoint/checkpoint.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """load/save checkpoint apis."""
-
 import os
 import json
 import tempfile
@@ -33,7 +32,6 @@ from mindspore.nn.optim.optimizer import Optimizer
 from mindspore.communication.management import get_rank, get_group_size
 from mindspore.communication import comm_func
 from mindspore import save_checkpoint as ms_save_checkpoint
-from mindspore.parallel.strategy import get_current_strategy_metadata
 
 from mindformers.tools.logger import logger
 from mindformers.checkpoint.reshard import ReshardHandler
@@ -43,6 +41,7 @@ from mindformers.tools.utils import (
     get_output_subpath,
     get_real_rank,
     set_safe_mode_for_file_or_dir,
+    get_real_group_size
 )
 from mindformers.checkpoint.utils import (
     get_checkpoint_iter_dir,
@@ -55,20 +54,17 @@ from mindformers.checkpoint.utils import (
     verify_ckpt_valid,
     FileType
 )
-from mindformers.checkpoint.fully_parallel import BalancedSaveStrategy
+from mindformers.checkpoint.fully_parallel import BalancedSaveStrategy, apply_balance_shard_strategy
 from mindformers.checkpoint.metadata import (
     save_metadata,
     load_metadata,
     generate_default_metadata_from_checkpoint,
-    get_total_shard_metadata,
-    get_total_params_file_mapping_info
+    get_total_params_file_mapping_info,
 )
 from mindformers.checkpoint.sharded_tensor import (
-    get_sharded_tensor_list_from_strategy_metadata,
-    get_sharded_tensor_list_from_cell,
     convert_sharded_tensor_list_to_dict,
     get_strategy_info_from_sharded_tensor,
-    ShardedTensor
+    ShardedTensor, get_sharded_tensor_list_from_cell, get_cur_sharded_tensor
 )
 
 
@@ -237,7 +233,7 @@ class AsyncSaveManager:
 
         # Async thread
         for thread in threading.enumerate():
-            if thread.getName() == "asyn_save_ckpt":
+            if thread.name == "asyn_save_ckpt":
                 if wait_finish:
                     thread.join()
                 return False
@@ -267,7 +263,7 @@ class AsyncSaveManager:
 def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
                     async_save_manager: AsyncSaveManager = None, common_info: CommonInfo = None,
                     keep_max_num: int = 5, user_prefix: str = None, save_checkpoint_path: str = None,
-                    global_strategy_info: List[Dict] = None, remove_redundancy: bool = False):
+                    sharded_tensor_metas: list = None, remove_redundancy: bool = False):
     """
     Saves the current state of the training process, including the model, optimizer,
     and learning rate scheduler, to a checkpoint file.
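# A minimal sketch of how the reworked `save_checkpoint` signature might be called,
# modeled on the CheckpointMonitor call site later in this patch. `net`, `opt` and the
# prefix are hypothetical stand-ins for a Cell and Optimizer built elsewhere.
from mindformers.checkpoint.checkpoint import save_checkpoint
from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor

# Per-rank ShardedTensor metadata now travels in place of the old `global_strategy_info`.
metas = get_all_sharded_tensor(network=net, filter_func=None)
save_checkpoint(
    iteration=1000,
    network=net,
    optimizer=opt,
    user_prefix="llama",         # hypothetical prefix
    sharded_tensor_metas=metas,
    remove_redundancy=True,      # routes model/optimizer weights through BalancedSaveStrategy
)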
@@ -283,7 +279,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
         save_checkpoint_path (str): The user can specify the path to save the weights.
             If None, the default path is 'output_dir/checkpoint'. And 'output_dir' is configured
             in yaml and defaults to './output' in the execution script path.
-        global_strategy_info (List[Dict]): The strategy info of this network.
+        sharded_tensor_metas (list): The per-rank ShardedTensor metadata of this network.
         remove_redundancy (bool): Whether to remove redundancy of saving checkpoint.
     """
     logger.info('....... Start to save checkpoint as new format .......')
@@ -351,7 +347,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
     model_keys = network.parameters_dict().keys()
 
     start_save_ckpt_time = time()
-    if remove_redundancy and global_strategy_info is not None:
+    if remove_redundancy and sharded_tensor_metas is not None:
         remove_model_redundancy = BalancedSaveStrategy(
             network,
             user_prefix=user_prefix,
@@ -374,7 +370,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
 
     # Save optimizer weight.
     if optimizer is not None:
-        if remove_redundancy and global_strategy_info is not None:
+        if remove_redundancy and sharded_tensor_metas is not None:
             # Optimizer weight remove redundancy.
             remove_optimizer_redundancy = BalancedSaveStrategy(
                 optimizer,
@@ -415,7 +411,7 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
     # Save 'metadata.json'.
     if not remove_redundancy:
         metadata_file_path = get_metadata_filename(checkpoints_root_path, iteration)
-        save_metadata_json(global_strategy_info, model_keys, user_prefix, metadata_file_path, optimizer is not None)
+        save_metadata_json(sharded_tensor_metas, model_keys, user_prefix, metadata_file_path)
 
     # Save tracker file in sync save process.
     if not use_async_save:
@@ -426,16 +422,11 @@ def save_checkpoint(iteration: int, network: Cell, optimizer: Optimizer = None,
     logger.info(f"Save checkpoint cost time: {time() - start_save_ckpt_time:.3f}s.")
 
 
-def save_metadata_json(global_strategy_info, model_keys, user_prefix, metadata_file_path, save_optimizer):
-    """Saving metadata.json used `get_strategy_metadata` API."""
-    if global_strategy_info is not None:
+def save_metadata_json(sharded_tensor_metas, model_keys, user_prefix, metadata_file_path):
+    """Save 'metadata.json' from the given sharded tensor metas."""
+    if sharded_tensor_metas is not None:
         logger.info("...... Start saving metadata ......")
-        if get_rank() == 0:
-            sharded_tensor_metas = get_total_shard_metadata(
-                global_strategy_info=global_strategy_info,
-                filter_func=(lambda x: x in list(model_keys)) if not save_optimizer else None
-            )
         param_file_mappings = get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_keys)
         save_metadata(sharded_tensor_metas, param_file_mappings, metadata_file_path)
 
 
@@ -666,7 +657,7 @@ def load_tensor_by_offset(
 
 
 def categorize_params(
-        dst_sharded_tensor_metas: Dict[str, List[ShardedTensor]],
+        dst_sharded_tensor_metas: Dict[str, ShardedTensor],
         src_sharded_tensor_metas: Dict[str, List[ShardedTensor]],
         param_file_mappings: Dict[str, List[Dict[str, Any]]]
 ) -> Tuple[List[str], Dict[str, Dict[str, List[Any]]], Dict[str, Dict[str, List[Any]]], Dict[str, List[Any]]]:
@@ -900,7 +891,7 @@ def load_checkpoint(
     network: Cell,
     optimizer: Optional[Optimizer] = None,
     global_step: Optional[int] = None,
-    balanced_load: bool = False,
+    balanced_load: bool = False
 ) -> None:
     """
     Loads a checkpoint into a network and optional optimizer.
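# The matching load-side sketch, under the same assumptions (`net` and `opt` are
# hypothetical): with `balanced_load=True`, the new code path below resolves the current
# rank's target shards via `apply_balance_shard_strategy` and forwards
# `remove_redundancy=balanced_load` into `load_param_into_net`.
from mindformers.checkpoint.checkpoint import load_checkpoint

load_checkpoint(
    checkpoint="./output/checkpoint",  # hypothetical directory with per-iteration subdirs
    network=net,
    optimizer=opt,
    balanced_load=True,
)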
@@ -920,7 +911,7 @@ def load_checkpoint(
         (Other exceptions may be raised by dependent functions for checkpoint validation/loading)
     """
     # Validate mandatory network parameter
-    check_the_param_for_load_ckpt(balanced_load, checkpoint, network)
+    check_the_param_for_load_ckpt(checkpoint, network)
 
     # Determine checkpoint directory path
     checkpoint_dir = get_checkpoint_path(checkpoint)
@@ -947,31 +938,23 @@ def load_checkpoint(
             "Metadata must include both sharded tensor information and parameter-file mappings."
         )
 
-    # Get current strategy metadata from network and optimizer
-    logger.info(".........Get Current Strategy Metadata.........")
-    cur_rank_strategy_layout = get_current_strategy_metadata(network=network)
-    cur_rank_sharded_tensors: List[ShardedTensor] = []
+    # Define parameter filtering function
+    def filter_func(param_name: str) -> bool:
+        if optimizer:
+            return "accu_grads" not in param_name
+        return param_name in list(network.parameters_dict().keys())
 
-    if cur_rank_strategy_layout:
-        # Convert strategy layout to required format
-        cur_rank_strategy_layout = [dict([item]) for item in cur_rank_strategy_layout[0].items()]
-
-        # Define parameter filtering function
-        def filter_fun(param_name: str) -> bool:
-            if optimizer:
-                return "accu_grads" not in param_name
-            return param_name in list(network.parameters_dict().keys())
-
-        # Get sharded tensors from strategy metadata
-        cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata(
-            param_infos=cur_rank_strategy_layout, cur_npu_rank=get_real_rank(), filter_func=filter_fun
-        )
+    if balanced_load:
+        dst_sharded_tensor_metas = apply_balance_shard_strategy(network, filter_func)[-1]
     else:
-        # Fallback: Get sharded tensors directly from network and optimizer
-        cur_rank_sharded_tensors = get_sharded_tensor_list_from_cell(network, optimizer)
+        if get_real_group_size() > 1:
+            cur_rank_sharded_tensors = get_cur_sharded_tensor(network, filter_func)
+        else:
+            # Fallback: Get sharded tensors directly from network and optimizer
+            cur_rank_sharded_tensors = get_sharded_tensor_list_from_cell(network, optimizer)
 
-    # Convert list of sharded tensors to dictionary for lookup
-    dst_sharded_tensor_metas = convert_sharded_tensor_list_to_dict(cur_rank_sharded_tensors)
+        # Convert list of sharded tensors to dictionary for lookup
+        dst_sharded_tensor_metas = convert_sharded_tensor_list_to_dict(cur_rank_sharded_tensors)
 
     # Categorize parameters based on sharding strategies
     _, need_concat_params, no_shard_params, online_shard_params = categorize_params(
@@ -1020,7 +1003,7 @@ def load_checkpoint(
     )
 
     # Load state dictionary into network and optimizer
-    load_parameters(network, state_dict, optimizer)
+    load_parameters(network, state_dict, optimizer, balanced_load=balanced_load)
 
 
 def concat_params(checkpoint_dir: str, core_network, key_mapping: dict, need_concat_params, state_dict: dict):
@@ -1059,15 +1042,11 @@ def concat_params(checkpoint_dir: str, core_network, key_mapping: dict, need_con
             state_dict[param_name] = Parameter(concated_weight[param_name], name=param_name, requires_grad=False)
 
 
-def check_the_param_for_load_ckpt(balanced_load: bool, checkpoint: str, network: Cell):
-    """Check the params passing in `load_checkpoint` method is legal."""
+def check_the_param_for_load_ckpt(checkpoint: str, network: Cell):
+    """Check that the params passed to `load_checkpoint` are legal."""
     if network is None:
         raise ValueError("The 'network' cannot be None - a target network is required for loading.")
 
-    if balanced_load:
-        raise ValueError("The balanced loading strategy is not supported yet. "
-                         "`balanced_load` is a preset switch, please set `balanced_load` to False.")
-
     if not os.path.exists(checkpoint):
         raise ValueError(f"Checkpoint does not exist: {checkpoint}")
 
@@ -1077,7 +1056,8 @@ def load_parameters(
     state_dict: Dict[str, Parameter],
     optimizer: Optional[Cell] = None,
     state_dict_opt: Optional[Dict[str, Parameter]] = None,
-) -> Tuple[List[str], List[str], Optional[List[str]], Optional[List[str]]]:
+    balanced_load: bool = False
+):
     """
     Loads parameters into network and optimizer.
@@ -1108,7 +1088,7 @@ def load_parameters(
     # Load parameters into network
     param_not_load, ckpt_not_load = [], []
     logger.debug(f"Network state_dict keys: {list(state_dict.keys())}")
-    param_not_load, ckpt_not_load = load_param_into_net(network, state_dict)
+    param_not_load, ckpt_not_load = load_param_into_net(network, state_dict, remove_redundancy=balanced_load)
 
     logger.info(f"Network parameters not loaded: {list(param_not_load)}")
     logger.info(f"Checkpoint weights not loaded: {list(ckpt_not_load)}")
@@ -1122,7 +1102,11 @@ def load_parameters(
         if param_name in optimizer_param_names and param_name not in state_dict_opt:
             state_dict_opt[param_name] = state_dict.pop(param_name)
     logger.debug(f"Optimizer state_dict keys: {list(state_dict_opt.keys())}")
-    param_not_load_opt, ckpt_not_load_opt = load_param_into_net(optimizer, state_dict_opt)
+    param_not_load_opt, ckpt_not_load_opt = load_param_into_net(
+        optimizer,
+        state_dict_opt,
+        remove_redundancy=balanced_load
+    )
     logger.info(f"Optimizer parameters not loaded: {list(param_not_load_opt)}")
     logger.info(f"Optimizer weights not loaded: {list(ckpt_not_load_opt)}")
 
diff --git a/mindformers/checkpoint/fully_parallel.py b/mindformers/checkpoint/fully_parallel.py
index 07a377de6..730303696 100644
--- a/mindformers/checkpoint/fully_parallel.py
+++ b/mindformers/checkpoint/fully_parallel.py
@@ -13,24 +13,25 @@
 # limitations under the License.
 # ============================================================================
 """save / load parallelization strategy."""
-
 import os
 from collections import defaultdict
 
-from mindspore.communication import get_rank
 from mindspore import save_checkpoint
+from mindspore.communication import get_rank
+from mindspore.nn import Cell
 
+from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor
 from mindformers.tools.logger import logger
-from mindformers.checkpoint.metadata import save_metadata, load_metadata, get_total_shard_metadata
+from mindformers.checkpoint.metadata import save_metadata, load_metadata
 from mindformers.checkpoint.utils import (
-    _sharded_tensor_shard_id,
-    _get_shard_size,
     _reverse_sharded_tensor_shard_id,
     get_checkpoint_iter_dir,
     get_metadata_filename,
     get_checkpoint_name,
-    FileType
+    FileType,
+    _get_shard_size,
+    sharded_tensor_shard_id
 )
 
 
 class BalancedSaveStrategy():
@@ -61,7 +63,7 @@ class BalancedSaveStrategy():
         """
         Initialize the BalancedSaveStrategy object.
         """
-        super(BalancedSaveStrategy, self).__init__()
+        super().__init__()
         self.user_prefix = user_prefix
         self.do_cache_distribution = do_cache_distribution
         self.total_files_num = None
@@ -85,8 +87,8 @@ class BalancedSaveStrategy():
            The total number of checkpoint files.
         """
         if self.total_files_num is None:
-            shared_distribution, shard_to_name = self.apply_saving_parallelization()
-            rank_params_mappings = self._get_rank_params_mappings(shared_distribution, shard_to_name)
+            shared_distribution, id_to_tensor = self.apply_saving_parallelization()
+            rank_params_mappings = self._get_rank_params_mappings(shared_distribution, id_to_tensor)
             self.total_files_num = self._get_total_files_num(rank_params_mappings)
         return self.total_files_num
 
@@ -102,8 +104,8 @@ class BalancedSaveStrategy():
            The identifier for the current rank's checkpoint file.
         """
         if self.cur_rank_file_id is None:
-            shared_distribution, shard_to_name = self.apply_saving_parallelization()
-            rank_params_mappings = self._get_rank_params_mappings(shared_distribution, shard_to_name)
+            shared_distribution, id_to_tensor = self.apply_saving_parallelization()
+            rank_params_mappings = self._get_rank_params_mappings(shared_distribution, id_to_tensor)
             self.cur_rank_file_id = self._get_cur_rank_file_id(rank_params_mappings)
         return self.cur_rank_file_id
 
@@ -117,8 +119,8 @@ class BalancedSaveStrategy():
         and saves the selected parameters in the specified format. It also saves metadata
         about the checkpoint files if the current rank is 0.
         """
-        shared_distribution, shard_to_name = self.apply_saving_parallelization()
-        rank_params_mappings = self._get_rank_params_mappings(shared_distribution, shard_to_name)
+        shared_distribution, id_to_tensor = self.apply_saving_parallelization()
+        rank_params_mappings = self._get_rank_params_mappings(shared_distribution, id_to_tensor)
         if self.total_files_num is None:
             self.total_files_num = self._get_total_files_num(rank_params_mappings)
 
@@ -162,7 +164,8 @@ class BalancedSaveStrategy():
         if self.do_cache_distribution and self.cached_distribution is not None:
             shared_distribution = self.cached_distribution
         else:
-            shared_distribution = get_shared_distribution(self.network, self.filter_func)
+            shard_to_saving_rank, shard_id_to_tensor, _ = apply_balance_shard_strategy(self.network, self.filter_func)
+            shared_distribution = (shard_to_saving_rank, shard_id_to_tensor)
             if self.do_cache_distribution:
                 self.cached_distribution = shared_distribution
 
@@ -236,9 +239,7 @@ class BalancedSaveStrategy():
                     (save_file_name + ".safetensors", rank_id, _reverse_sharded_tensor_shard_id(param_id)))
                 cur_rank_id += 1
 
-        from mindspore.parallel.strategy import get_strategy_metadata
-        strategy_info = get_strategy_metadata(self.network)
-        shard_to_metadata = get_total_shard_metadata(strategy_info, self.filter_func)
+        shard_to_metadata = get_all_sharded_tensor(self.network, self.filter_func)
 
         origin_metadata_file = get_metadata_filename(self.checkpoint_path, iteration)
         if os.path.exists(origin_metadata_file):
@@ -247,9 +248,11 @@ class BalancedSaveStrategy():
             shard_to_metadata.extend(list(origin_shard_metadata.values()))
             for param_id, storage in origin_param_file_mapping.items():
                 for storage_item in storage:
-                    param_file_mapping.append(
-                        (storage_item["file_name"], storage_item["storage_rank"],
-                         _reverse_sharded_tensor_shard_id(param_id)))
+                    param_file_mapping.append((
+                        storage_item["file_name"],
+                        storage_item["storage_rank"],
+                        _reverse_sharded_tensor_shard_id(param_id)
+                    ))
 
         metadata_file_path = get_metadata_filename(self.checkpoint_path, iteration)
         save_metadata(shard_to_metadata, param_file_mapping, metadata_file_path)
@@ -258,7 +261,7 @@ class BalancedSaveStrategy():
             f"The 'metadata.json' of non-redundancy weight saved successfully at '{metadata_file_path}'."
         )
 
-    def _get_rank_params_mappings(self, shared_distribution, shard_to_name):
+    def _get_rank_params_mappings(self, shared_distribution, id_to_tensor):
         """
         Create a mapping from rank IDs to lists of parameter names based on the shared distribution
-        and shard-to-name mapping.
+        and ID-to-tensor mapping.
@@ -266,7 +269,7 @@ class BalancedSaveStrategy():
         Args:
             shared_distribution (dict): A dictionary where keys are parameter IDs and values are rank IDs
                 indicating which rank is responsible for a particular parameter.
-            shard_to_name (dict): A dictionary that maps parameter IDs to their corresponding parameter names.
+            id_to_tensor (dict): A dictionary that maps parameter IDs to their corresponding ShardedTensor.
 
         Returns:
             A dictionary where keys are rank IDs and values are lists of parameter names assigned to that rank.
@@ -274,9 +277,9 @@ class BalancedSaveStrategy():
         rank_params_mappings = {}
         for param_id, rank_id in shared_distribution.items():
             if rank_id not in rank_params_mappings:
-                rank_params_mappings[rank_id] = [shard_to_name[param_id]]
+                rank_params_mappings[rank_id] = [id_to_tensor[param_id].key]
             else:
-                rank_params_mappings[rank_id].append(shard_to_name[param_id])
+                rank_params_mappings[rank_id].append(id_to_tensor[param_id].key)
         sorted_rank_params_mappings = {
             k: rank_params_mappings.get(k, None)
             for k in sorted(rank_params_mappings)
@@ -309,54 +312,6 @@ class BalancedSaveStrategy():
 
         return sorted_rank_params_mappings
 
 
-def get_shared_distribution(network, filter_func=None):
-    """
-    Get the shared distribution of shards among ranks and the mapping from shard IDs to parameter names.
-
-    This function analyzes the total shard metadata to determine which ranks are responsible for each shard
-    and assigns shards to ranks using a greedy algorithm. It processes the network's sharded tensor metadata
-    to create mappings between shard IDs, ranks, and parameter names, focusing only on shards within the
-    current parallelization group (replica_id == 0).
-
-    Args:
-        network (cell): The neural network model whose shard distribution needs to be analyzed.
-        filter_func (func): Filter function that filters out parameters that do not need to be saved.
-
-    Returns:
-        A tuple containing a dictionary mapping shard IDs to the rank that will save the shard
-        and a dictionary mapping shard IDs to parameter names.
-    """
-    from mindspore.parallel.strategy import get_strategy_metadata
-    strategy_info = get_strategy_metadata(network)
-
-    total_shard_metadata = get_total_shard_metadata(strategy_info, filter_func)
-    shard_to_ranks = defaultdict(list)
-    shard_to_size = {}
-    shards_in_this_parallelization_group = set()
-    shard_to_name = {}
-
-    for rank, sharded_tensor_metas in enumerate(total_shard_metadata):
-        for tensor_meta in sharded_tensor_metas:
-            shard_id = _sharded_tensor_shard_id(tensor_meta.key, tensor_meta.global_offset)
-            shard_to_ranks[shard_id].append(rank)
-
-            if shard_id not in shard_to_size:
-                shard_to_size[shard_id] = _get_shard_size(tensor_meta.local_shape, tensor_meta.dtype)
-                shard_to_name[shard_id] = tensor_meta.key
-            shards_in_this_parallelization_group.add(shard_id)
-
-    shard_to_ranks = {
-        k: v
-        for k, v in shard_to_ranks.items()
-        if k in shards_in_this_parallelization_group
-    }
-
-    shard_to_saving_rank = distribute_shards(
-        shard_to_ranks, shard_to_size, len(total_shard_metadata)
-    )
-    return shard_to_saving_rank, shard_to_name
-
-
 def distribute_shards(shard_coverage, shard_sizes, total_ranks):
     """
     Distribute shards to ranks using a greedy algorithm based on the following priority:
@@ -392,3 +347,62 @@ def distribute_shards(shard_coverage, shard_sizes, total_ranks):
         rank_loads[selected_rank] += shard_sizes[shard_id]
 
     return shard_assignment
+
+
+def apply_balance_shard_strategy(network: Cell, filter_func):
+    """
+    Process and balance sharded tensor metadata across all ranks.
+
+    This function retrieves the strategy metadata of every rank from the network, processes the
+    sharding information, and distributes the shards across ranks with a greedy algorithm so
+    that each shard is handled by exactly one of the ranks that hold a replica of it.
+
+    Args:
+        network (Cell): The MindSpore network cell containing parameters and sharding strategies.
+        filter_func (Callable): Filter function that filters out parameters that do not need
+            to be processed; it receives a parameter name and returns a bool.
+
+    Returns:
+        tuple: A 3-tuple ``(shard_to_saving_rank, shard_id_to_tensor, dst_sharded_tensor_metas)``,
+            where ``shard_to_saving_rank`` maps each shard ID to the rank responsible for it,
+            ``shard_id_to_tensor`` maps each shard ID to its ShardedTensor metadata, and
+            ``dst_sharded_tensor_metas`` maps parameter names to the ShardedTensor metadata
+            assigned to the current rank.
+
+    Notes:
+        - Relies on MindSpore's `get_strategy_metadata` for strategy-based sharding info.
+        - Raises a RuntimeError (via `get_all_sharded_tensor`) when no strategy metadata
+          is available, e.g. in a non-distributed job.
+    """
+    total_shard_metadata = get_all_sharded_tensor(network, filter_func)
+    shard_id_to_ranks = defaultdict(list)
+    shard_to_size = {}
+    shards_in_this_parallelization_group = set()
+    shard_id_to_tensor = {}
+
+    for rank, sharded_tensor_metas in enumerate(total_shard_metadata):
+        for tensor_meta in sharded_tensor_metas:
+            shard_id = sharded_tensor_shard_id(tensor_meta.key, tensor_meta.global_offset)
+            shard_id_to_ranks[shard_id].append(rank)
+
+            if shard_id not in shard_to_size:
+                shard_to_size[shard_id] = _get_shard_size(tensor_meta.local_shape, tensor_meta.dtype)
+                shard_id_to_tensor[shard_id] = tensor_meta
+            shards_in_this_parallelization_group.add(shard_id)
+
+    shard_id_to_ranks = {
+        k: v
+        for k, v in shard_id_to_ranks.items()
+        if k in shards_in_this_parallelization_group
+    }
+
+    shard_to_saving_rank = distribute_shards(
+        shard_id_to_ranks, shard_to_size, len(total_shard_metadata)
+    )
+
+    dst_sharded_tensor_metas = {}  # {param_name: ShardedTensor}
+    cur_rank = get_rank()
+    for shard_id, rank_id in shard_to_saving_rank.items():
+        if rank_id == cur_rank:
+            dst_sharded_tensor_metas[_reverse_sharded_tensor_shard_id(shard_id)[0]] = shard_id_to_tensor[shard_id]
+    return shard_to_saving_rank, shard_id_to_tensor, dst_sharded_tensor_metas
diff --git a/mindformers/checkpoint/metadata.py b/mindformers/checkpoint/metadata.py
index 5f056d187..caa6c7a3c 100644
--- a/mindformers/checkpoint/metadata.py
+++ b/mindformers/checkpoint/metadata.py
@@ -25,7 +25,7 @@ from mindspore.parallel import Layout
 
 from mindformers.tools.logger import logger
 from mindformers.tools.utils import set_safe_mode_for_file_or_dir
-from mindformers.checkpoint.sharded_tensor import build_sharded_tensor, get_sharded_tensor_list_from_strategy_metadata
+from mindformers.checkpoint.sharded_tensor import build_sharded_tensor
 from mindformers.checkpoint.utils import (
     get_checkpoint_name,
     get_sharded_tensor_shard_id,
@@ -274,7 +274,7 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict
         raise NotADirectoryError(
             f"Checkpoint directory '{checkpoint_dir}' does not exist or is not a directory.")
 
-    logger.info(f"..........Load Metadata from Checkpoint Files..........")
+    logger.info("..........Load Metadata from Checkpoint Files..........")
 
     # Find all safetensor files in the checkpoint directory
     safetensor_pattern = os.path.join(checkpoint_dir, "*.safetensors")
@@ -309,9 +309,9 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict
 
                 # Extract tensor properties
                 tensor_shape = tensor.shape
-                ms_dtype = tensor_to_ms_type.get(tensor.dtype.__str__())
+                ms_dtype = tensor_to_ms_type.get(str(tensor.dtype))
                 global_offset = (0,)
-                axis_fragmentations = [1] * len(tensor_shape)
+                axis_fragmentations = (1,) * len(tensor_shape)
 
                 # Create sharded tensor metadata object
                 sharded_tensor = build_sharded_tensor(
@@ -337,37 +337,13 @@ def generate_default_metadata_from_checkpoint(checkpoint_dir: str) -> tuple[dict
     return sharded_tensor_metas, param_file_mappings
 
 
-def get_total_shard_metadata(global_strategy_info, filter_func):
-    """Get all shard metadata."""
-    npu_nums = get_group_size()
-    sharded_tensor_metas = list()
-
-    for cur_npu_rank in range(0, npu_nums):
-        org_cur_rank_strategy_layout = global_strategy_info[cur_npu_rank]
-        cur_rank_strategy_layout = [
-            dict([item])
-            for item in org_cur_rank_strategy_layout.items()
-        ]
-
-        # Get Sharded tensors from strategy metadata of current rank.
- cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata( - param_infos=cur_rank_strategy_layout, - cur_npu_rank=cur_npu_rank, - filter_func=filter_func - ) - - sharded_tensor_metas.append(cur_rank_sharded_tensors) - - return sharded_tensor_metas - - def get_total_params_file_mapping_info(sharded_tensor_metas, user_prefix, model_keys): """Get all shard metadata file mappings list.""" if sharded_tensor_metas is None: return None npu_nums = get_group_size() - param_file_mappings = list() + param_file_mappings = [] for cur_npu_rank, cur_rank_sharded_tensor_list in enumerate(sharded_tensor_metas): # Get mappings of parameter file of current rank. for sharded_tensor in cur_rank_sharded_tensor_list: diff --git a/mindformers/checkpoint/sharded_tensor.py b/mindformers/checkpoint/sharded_tensor.py index 041c4c018..b190418f5 100644 --- a/mindformers/checkpoint/sharded_tensor.py +++ b/mindformers/checkpoint/sharded_tensor.py @@ -13,15 +13,17 @@ # limitations under the License. # ============================================================================ """Sharded Tensor""" - from collections import defaultdict from dataclasses import dataclass from typing import List, Dict, Optional, Tuple, Union, Callable import mindspore as ms +from mindspore.communication import get_group_size from mindspore.nn import Cell from mindspore.parallel.shard import _DistributedTensorInfo +from mindspore.parallel.strategy import get_current_strategy_metadata, get_strategy_metadata +from mindformers.tools.utils import get_real_rank from mindformers.tools.logger import logger @@ -202,7 +204,7 @@ def _tensor_map_with_rank_id(cur_dev_matrix, flat_tensor_map, cur_alias_name, de return alias_rank_stride -def _rank_id_with_slice_id(alias_rank_stride) -> List: +def _rank_id_with_slice_id(alias_rank_stride): """Rank_id vs Slice_id, rank_id ranges from 0 to dev_num.""" cur_global_offset: Tuple[int, ...] = () rank_num = len(alias_rank_stride[0][0]) @@ -438,7 +440,7 @@ def get_sharded_tensor_list_from_cell( param_dtype = param.data.dtype param_shape = param.data.shape global_offset = (0,) - axis_fragmentations = [1] * len(param_shape) + axis_fragmentations = (1,) * len(param_shape) # Create and add sharded tensor metadata sharded_tensor = build_sharded_tensor( @@ -491,3 +493,45 @@ def convert_sharded_tensor_list_to_dict( sharded_tensor_dict[param_name] = sharded_tensor return sharded_tensor_dict + + +def get_all_sharded_tensor(network, filter_func) -> list: + """Get all rank sharded tensors.""" + logger.info(".........Get All Ranks' Strategy Metadata.........") + global_strategy_info = get_strategy_metadata(network) + if not global_strategy_info: + raise RuntimeError('`get_strategy_metadata` returns `None`, which indicates there is no strategy info. ' + 'Please check whether this is a distributed job.') + + npu_nums = get_group_size() + sharded_tensor_metas = [] + + for cur_npu_rank in range(0, npu_nums): + org_cur_rank_strategy_layout = global_strategy_info[cur_npu_rank] + cur_rank_strategy_layout = [ + dict([item]) + for item in org_cur_rank_strategy_layout.items() + ] + + # Get Sharded tensors from strategy metadata of current rank. 
+ cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata( + param_infos=cur_rank_strategy_layout, + cur_npu_rank=cur_npu_rank, + filter_func=filter_func + ) + + sharded_tensor_metas.append(cur_rank_sharded_tensors) + return sharded_tensor_metas + + +def get_cur_sharded_tensor(network, filter_func): + """Get current rank sharded tensors.""" + logger.info(".........Get Current Strategy Metadata.........") + strategy_info = get_current_strategy_metadata(network) + # Convert strategy layout to required format + strategy_info = [dict([item]) for item in strategy_info[0].items()] + # Get sharded tensors from strategy metadata + cur_rank_sharded_tensors = get_sharded_tensor_list_from_strategy_metadata( + param_infos=strategy_info, cur_npu_rank=get_real_rank(), filter_func=filter_func + ) + return cur_rank_sharded_tensors diff --git a/mindformers/checkpoint/utils.py b/mindformers/checkpoint/utils.py index f6745008e..d0c6c0ae2 100644 --- a/mindformers/checkpoint/utils.py +++ b/mindformers/checkpoint/utils.py @@ -117,7 +117,7 @@ def get_checkpoint_iter_dir(checkpoints_path: str, iteration: int) -> str: if not isinstance(iteration, int): raise ValueError(f"'iteration' must be an integer! But got '{type(iteration)}'.") - directory = PER_ITERATION_CKPT_DIR_PREFIX + '{:08d}'.format(iteration) + directory = f'{PER_ITERATION_CKPT_DIR_PREFIX}{iteration:08d}' iter_dir = os.path.join(checkpoints_path, directory) return iter_dir @@ -166,7 +166,7 @@ def get_metadata_filename(checkpoints_path: str, iteration: int) -> str: return os.path.join(metadata_path, 'metadata.json') -def get_latest_iteration_from_tracker(checkpoints_path: str) -> bool: +def get_latest_iteration_from_tracker(checkpoints_path: str) -> int: """ Get the iteration tracker file content. Used in resume scene. @@ -186,7 +186,7 @@ def get_latest_iteration_from_tracker(checkpoints_path: str) -> bool: raise FileNotFoundError(f"No tracker file found in load directory: '{tracker_filename}'.") # Get the latest iteration number from the tracker file. - with open(tracker_filename, 'r') as f: + with open(tracker_filename, 'r', encoding="utf-8") as f: iter_string = f.read().strip() try: iteration = int(iter_string) @@ -204,14 +204,14 @@ def get_latest_iteration_from_tracker(checkpoints_path: str) -> bool: return iteration -def get_checkpoint_name(cur_iter_checkpoint_dir: str, user_prefix: str, file_idx: int, total_file_num: int, - file_type: FileType) -> str: +def get_checkpoint_name(cur_iter_checkpoint_dir: Optional[str], user_prefix: Optional[str], file_idx: int, + total_file_num: int, file_type: FileType) -> str: """ Generate a checkpoint name for model parameters or optimizer parameters. Args: - cur_iter_checkpoint_dir (str): Currently iteration checkpoint path. - user_prefix (str): The prefix to use for the checkpoint file name. + cur_iter_checkpoint_dir (Optional[str]): Currently iteration checkpoint path. + user_prefix (Optional[str]): The prefix to use for the checkpoint file name. file_idx (int): The index of the current file. total_file_num (int): The total number of files. file_type (str): The type of the file (e.g., model parameters, optimizer parameters). @@ -245,7 +245,7 @@ def get_sharded_tensor_shard_id(param_name, global_offset): return str(tuple((param_name, tuple(global_offset)))) -def _sharded_tensor_shard_id(param_name, global_offset): +def sharded_tensor_shard_id(param_name, global_offset): """ Generate a unique identifier for a sharded tensor based on its parameter name and global offset. 
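# Illustration of the shard-ID encoding used by `get_sharded_tensor_shard_id` above
# (and, presumably, by the renamed `sharded_tensor_shard_id`): a (param_name,
# global_offset) pair stringified via `str(tuple(...))`, e.g.
#     get_sharded_tensor_shard_id("w", (0, 2))  ->  "('w', (0, 2))"
# `_reverse_sharded_tensor_shard_id` itself is not shown in this patch; below is a
# plausible literal-eval based sketch of the inverse, not the actual implementation.
import ast

def reverse_shard_id_sketch(shard_id: str):
    """Recover the (param_name, global_offset) pair from a stringified shard ID."""
    param_name, global_offset = ast.literal_eval(shard_id)
    return param_name, tuple(global_offset)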
@@ -321,7 +321,7 @@ def verify_ckpt_valid(checkpoint_dir: str) -> Optional[str]: checkpoint_dir: Path to the checkpoint directory to validate. Returns: - Optional[str]: `None` if validation passes.. + Optional[str]: `None` if validation passes. Raises: NotADirectoryError: If `checkpoint_dir` does not exist or is not a directory. diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index 7cfac6b61..60f272904 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -58,6 +58,7 @@ from mindspore.communication.comm_func import all_gather_into_tensor, barrier from mindspore.profiler import ProfilerLevel, schedule from mindspore.utils import stress_detect +from mindformers.checkpoint.sharded_tensor import get_all_sharded_tensor from mindformers.core.context.build_context import is_legacy_model from mindformers.tools import get_output_root_path from mindformers.tools.logger import logger @@ -195,26 +196,26 @@ def _get_max_eigenvalue(input_tensor, num_iter): Calculate max eigenvalue https://www.cnblogs.com/zxn-share/p/17392450.html """ - input_tensor = input_tensor.astype(ms.float32) # (m,n) or (b,m,n) - in_features = input_tensor.shape[-1] # (n) + input_tensor = input_tensor.astype(ms.float32) # (m,n) or (b,m,n) + in_features = input_tensor.shape[-1] # (n) u_tensor = None for _ in range(5): - u_tensor = ms.ops.randn(in_features) # (n) + u_tensor = ms.ops.randn(in_features) # (n) u_norm = u_tensor.norm() if u_norm.asnumpy() > 0: break else: logger.warning("Calculate max eigenvalue: the norm of a randomly generated vector is 0") return 0.0 - u_tensor = u_tensor / u_tensor.norm() # (n) + u_tensor = u_tensor / u_tensor.norm() # (n) input_seq = ms.ops.matmul(input_tensor.transpose(-2, -1), input_tensor) # (n.n) or (b,n,n) if input_tensor.ndim == 2: input_seq = ms.ops.unsqueeze(input_seq, 0) # (1,n,n) - u_tensor = ms.ops.unsqueeze(u_tensor, 1) # (n,1) + u_tensor = ms.ops.unsqueeze(u_tensor, 1) # (n,1) for _ in range(num_iter): - v_tensor = ms.ops.matmul(input_seq, u_tensor) # (b,n,n) * (n,1) = (b,n,1) + v_tensor = ms.ops.matmul(input_seq, u_tensor) # (b,n,n) * (n,1) = (b,n,1) eigenvalue = ms.ops.matmul(v_tensor.transpose(-2, -1), u_tensor).squeeze() # (b,1,n) * (b,n,1) = b - v_norm = v_tensor.norm(dim=1, keepdim=True) # (b,1,1) + v_norm = v_tensor.norm(dim=1, keepdim=True) # (b,1,1) if (v_norm != 0).all(): u_tensor = v_tensor / v_norm else: @@ -753,6 +754,7 @@ class TrainingStateMonitor(Callback): use_local_norm (bool, optional): Whether to turn on the local norm. Default: ``False``. Default: ``False``. 
""" + @args_type_check(embedding_size=int, use_skip_data_by_global_norm=bool) def __init__(self, origin_epochs: int, @@ -1299,7 +1301,7 @@ class TrainingStateMonitor(Callback): if 'tensorboard' in self.max_attention_logit_format: tp_id = get_rank() // self.tensor_model_parallel_size head_start = tp_id * len(v) - data = {f"head_{head_start+i}": max_attention_logit for i, max_attention_logit in enumerate(v)} + data = {f"head_{head_start + i}": max_attention_logit for i, max_attention_logit in enumerate(v)} self._output(tag, data, step, ['tensorboard']) vals.append(v) @@ -1579,7 +1581,7 @@ class CheckpointMonitor(ModelCheckpoint): self.health_ckpts_record_dir = health_ckpts_record_dir self.use_legacy_format = use_legacy_format # Ensure that 'save_optimizer' only use in the sense of 'use_legacy_format == False' - self.save_optimizer = save_optimizer if not use_legacy_format else None + self.save_optimizer = save_optimizer if not use_legacy_format else False self.origin_prefix = prefix self.save_checkpoint_path = save_checkpoint_path self.need_remove_redundancy = remove_redundancy @@ -1691,7 +1693,7 @@ class CheckpointMonitor(ModelCheckpoint): for record_step in keys: self.print_savetime(record_step, cb_params.batch_num) if not any(self.save_info_list[record_step][key]['ckpt_file_path'] - for key in ['ckpt', 'network', 'trainable_params']): + for key in ['ckpt', 'network', 'trainable_params']): self.save_info_list.pop(record_step) if self._config.async_save and not ms.async_ckpt_thread_status() and \ @@ -1853,6 +1855,7 @@ class CheckpointMonitor(ModelCheckpoint): def _tft_save_ckpt(self, param_layout_set, save_param_names, cur_file, append_dict, network): """save checkpoint with remove redundancy for TFT training.""" + def choice_func(x): return (x not in param_layout_set or (save_param_names is not None and x in save_param_names)) \ and self._filter_ckpt_not_save(x, self.filter_list) @@ -2036,9 +2039,12 @@ class CheckpointMonitor(ModelCheckpoint): self.common_info.loss_scale = float(cb_params.net_outputs[2]) self.common_info.global_batch_size = self.global_batch_size - from mindspore.parallel.strategy import get_strategy_metadata - # Get all strategy info of this network to save 'metadata.json' - global_strategy_info = get_strategy_metadata(network=cb_params.network) + # Get all sharded tensor info of this network to save 'metadata.json' + sharded_tensor_metas = get_all_sharded_tensor( + network=cb_params.network, + filter_func=(lambda x: x in list( + cb_params.network.network.parameters_dict().keys())) if not self.save_optimizer else None + ) save_checkpoint( iteration=iteration, @@ -2049,7 +2055,7 @@ class CheckpointMonitor(ModelCheckpoint): keep_max_num=self._config.keep_checkpoint_max, user_prefix=self.origin_prefix, save_checkpoint_path=self.save_checkpoint_path, - global_strategy_info=global_strategy_info, + sharded_tensor_metas=sharded_tensor_metas, remove_redundancy=self.need_remove_redundancy ) @@ -2684,7 +2690,6 @@ class StressDetectCallBack(Callback): logger.warning(f"detection_interval = {self.detection_interval} is bigger than " f"steps_per_epoch = {self.steps_per_epoch}") - def on_train_step_end(self, run_context): """ Stress detect at the end of step. 
@@ -2705,7 +2710,6 @@ class StressDetectCallBack(Callback): self.log_stress_detect_result(detect_ret_list) - @staticmethod def log_stress_detect_result(detect_ret_list): """print output information.""" @@ -2725,10 +2729,6 @@ class MaxLogitsMonitor(Callback): This callback resets the maximum attention logit values at the end of each training step. """ - - def __init__(self,): - pass - def _reset_max_attention_logit(self, network): """Reset max attention logit in the network. @@ -2773,6 +2773,7 @@ class TopkBiasBalanceCallback(Callback): micro_batch_num (int, optional): Micro batch number in pipeline parallel. Default to 1. gradient_accumulation_steps (int, optional): Gradient accumulation steps for training. Default to 1. """ + def __init__(self, balance_via_topk_bias: bool = False, topk_bias_update_rate: float = 0.0, @@ -2875,6 +2876,7 @@ class MoEDropRateCallback(Callback): >>> stop_step = MoEDropRateCallback(expert_num=8, capacity_factor=1.5, num_layers=4, mtp_depth=1) """ + def __init__(self, expert_num: int, capacity_factor: float, @@ -2945,6 +2947,7 @@ class StressTestModelMonitor(Callback): stress_test_log_dir (str optional): The directory where the stress test training log is stored. check_stresslog_interval_time (int, optional): Time interval where the stress test log is checked. """ + def __init__(self, interval_steps=10, stress_model_dir=None, @@ -2972,7 +2975,7 @@ class StressTestModelMonitor(Callback): self.stress_master_port = 8338 if self.stress_master_port == self.main_master_port: logger.warning("For StressTestMonitor, stress_master_port must be different from the main task " - f"but both got {self.stress_master_port}. Setting to {self.stress_master_port+1}") + f"but both got {self.stress_master_port}. Setting to {self.stress_master_port + 1}") self.stress_master_port += 1 logger.warning(f"Make sure that the new port {self.stress_master_port} is unoccupied.") self.worker_num = ms.communication.get_local_rank_size() @@ -3016,8 +3019,8 @@ class StressTestModelMonitor(Callback): rank_id = get_rank() if rank_id % self.worker_num == 0: - node_num = rank_id//self.worker_num - saved_dir = os.path.join(self.stress_test_log_dir, "node"+str(node_num)) + node_num = rank_id // self.worker_num + saved_dir = os.path.join(self.stress_test_log_dir, "node" + str(node_num)) command = f"""taskset -c {cpu_cores} bash scripts/msrun_launcher.sh "run_mindformer.py \ --config {self.model_dir} \ --use_parallel True\ @@ -3085,8 +3088,8 @@ class StressTestModelMonitor(Callback): last_step_results = self.extract_last_step_result(log_file_path) barrier() - gathered_results, _ = all_gather_into_tensor(last_step_results) # - gathered_results = gathered_results.asnumpy() # + gathered_results, _ = all_gather_into_tensor(last_step_results) # + gathered_results = gathered_results.asnumpy() # logger.debug("Collected last step results are gathered_results.") logger.info("Last step results are collected from each rank, now starting to compare last step results") @@ -3141,7 +3144,7 @@ class StressTestModelMonitor(Callback): global_step_number = (epoch_number - 1) * steps_per_epoch + step_number # Consider logging only if it matches the interval - if global_step_number >= (self.compare_interval_steps+last_recorded_step): + if global_step_number >= (self.compare_interval_steps + last_recorded_step): loss_value = self.get_value_from_line(line, r"loss: (\d+\.\d+)") global_norm_value = self.get_value_from_line(line, r"global_norm: \[(\d+\.\d+)\]") results.append(Tensor([[epoch_number, step_number, loss_value, 
global_norm_value]], ms.float32)) @@ -3240,6 +3243,7 @@ class SDCMonitor(Callback): checksum_cooldown_time (int, optional): The cooldown time (minutes) of CheckSum after it stops. Default: ``180``. """ + def __init__(self, initial_step: int = 0, step_interval: int = 10, @@ -3258,12 +3262,12 @@ class SDCMonitor(Callback): self.initial_step = initial_step self.step_interval = step_interval - self.step_times = {datetime.now(): initial_step} # {timestamp: step} - self.silent_check_error_times = {} # {timestamp: step} + self.step_times = {datetime.now(): initial_step} # {timestamp: step} + self.silent_check_error_times = {} # {timestamp: step} self.strike_window_time = timedelta(minutes=strike_window_time) self.strike_num = strike_num self.checksum_enable = False - self.prev_checksum_time = datetime.min # start/stop time + self.prev_checksum_time = datetime.min # start/stop time self.checksum_time = timedelta(minutes=checksum_time) self.checksum_cooldown_time = timedelta(minutes=checksum_cooldown_time) @@ -3277,7 +3281,7 @@ class SDCMonitor(Callback): self.log_time_pattern = re.compile(r'(\d{4}-\d{2}-\d{2}-\d{2}:\d{2}:\d{2}\.\d{3}\.\d{3})') logger.info(f"Device log path: {self.device_log_path}, pid: {pid}") - self.all_reduce_net = AllReduceNet(GlobalComm.WORLD_COMM_GROUP) # AllReduce status and result of CheckSum + self.all_reduce_net = AllReduceNet(GlobalComm.WORLD_COMM_GROUP) # AllReduce status and result of CheckSum def _get_log_files_to_check(self): """Get device log filenames after last check and sort them by timestamp.""" @@ -3313,7 +3317,7 @@ class SDCMonitor(Callback): if not error_log_times: return {} # process from latest to earliest, stop early if error num reaches strike num - error_times = {} # {timestamp: step} + error_times = {} # {timestamp: step} step_time_list = list(self.step_times.keys()) index = len(step_time_list) - 1 for log_time in reversed(error_log_times): @@ -3328,7 +3332,7 @@ class SDCMonitor(Callback): logger.warning(f"SilentCheck detect SDC at step: {step}") error_times[log_time] = step index -= 1 - return dict(reversed(list(error_times.items()))) # order from earliest to latest + return dict(reversed(list(error_times.items()))) # order from earliest to latest def _update_silent_check_error_times(self, new_silent_check_error_times, now): """Add new SilentCheck error times and remove expired ones.""" diff --git a/mindformers/core/config_args.py b/mindformers/core/config_args.py index 82652d557..ea3f0f46a 100644 --- a/mindformers/core/config_args.py +++ b/mindformers/core/config_args.py @@ -488,7 +488,7 @@ class MFContextConfig(BaseArgsConfig): 'src_strategy_path_or_dir', 'auto_trans_ckpt', 'only_save_strategy', - 'ckpt_use_legacy_format', + 'use_legacy_format', 'balanced_load', 'run_mode', 'use_legacy', diff --git a/mindformers/tools/register/template.py b/mindformers/tools/register/template.py index 1a875da1c..f34ad613f 100644 --- a/mindformers/tools/register/template.py +++ b/mindformers/tools/register/template.py @@ -250,7 +250,7 @@ class GeneralConfig(Config): only_save_strategy = False load_ckpt_async = False use_legacy = True - ckpt_use_legacy_format = True + use_legacy_format = True pretrained_model_dir = "" balanced_load = False diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index eae453e4a..7c43e20ba 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -1128,7 +1128,7 @@ class BaseTrainer: logger.info("Create train dataset finish, dataset size:%d", dataset.get_dataset_size()) 
append_info = None - if not config.ckpt_use_legacy_format: + if not config.use_legacy_format: if config.resume_training and config.load_checkpoint and not check_is_reboot_node(): logger.info(".............Start load resume context from common.json..................") common_file = os.path.join(config.load_checkpoint, 'common.json') @@ -1217,17 +1217,13 @@ class BaseTrainer: check_rules(config, mode='train', network=network, dataset=dataset) # It is necessary to enable save strategy online before the model compilation phase, - # if save checkpoint with megatron format. - save_checkpoint_with_legacy_format = True - for callback in self.config.callbacks: - if "type" in callback and callback["type"] == "CheckpointMonitor": - save_checkpoint_with_legacy_format = callback.get("use_legacy_format", True) - if not save_checkpoint_with_legacy_format or not config.ckpt_use_legacy_format: + # if load or save checkpoint with megatron format. + if not config.get('use_legacy_format', True): try: from mindspore.parallel.strategy import enable_save_strategy_online enable_save_strategy_online() except ImportError: - logger.warning("If you want to save checkpoint with new format," + logger.warning("If you want to save or load checkpoint with new format," "please ensure your version of mindspore > 2.7.1, " "to support 'enable_save_strategy_online'.") @@ -1373,7 +1369,8 @@ class BaseTrainer: "embedding_size": embedding_size, "embedding_local_norm_threshold": embedding_local_norm_threshold, "use_checkpoint_health_monitor": use_checkpoint_health_monitor, - "health_ckpts_record_dir": config.output_dir + "health_ckpts_record_dir": config.output_dir, + "use_legacy_format": config.get('use_legacy_format', True) } if not config.get("use_legacy", True) and default_args.get("checkpoint_format") == "ckpt": logger.warning( @@ -1461,7 +1458,7 @@ class BaseTrainer: model = Model(network, optimizer=optimizer, metrics=compute_metrics, eval_network=eval_network) # resume checkpoint - if not config.ckpt_use_legacy_format and config.load_checkpoint: + if not config.use_legacy_format and config.load_checkpoint: if config.use_parallel: compile_model(model, dataset, mode=config.context.mode, sink_mode=config.runner_config.sink_mode, epoch=config.runner_config.epochs, sink_size=config.runner_config.sink_size) @@ -1477,11 +1474,18 @@ class BaseTrainer: f"(batch size changed from {common_info.global_batch_size} to {self.global_batch_size})" ) load_checkpoint( - checkpoint=config.load_checkpoint, network=network, optimizer=optimizer, global_step=global_step, + checkpoint=config.load_checkpoint, + network=model.train_network, + optimizer=optimizer, + global_step=global_step, balanced_load=config.balanced_load ) else: - load_checkpoint(checkpoint=config.load_checkpoint, network=network, balanced_load=config.balanced_load) + load_checkpoint( + checkpoint=config.load_checkpoint, + network=model.train_network, + balanced_load=config.balanced_load + ) elif (config.load_checkpoint or config.only_save_strategy) and not check_is_reboot_node(): if config.resume_training: logger.info(".............Start resume training from checkpoint..................") diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 8b1bafc4d..7e0f6471c 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -470,7 +470,7 @@ class Trainer: self._check_config_rules() self._init_model(is_train=True) # if enable_tft is False or remove_redundancy is True, can use record last_ckpt to json - if 
self.config.ckpt_use_legacy_format and self.config.resume_training and \ + if self.config.use_legacy_format and self.config.resume_training and \ (not check_tft_valid() or self.config.remove_redundancy): if os.path.isfile(self.config.load_checkpoint) and \ isinstance(self.config.resume_training, str): @@ -491,7 +491,7 @@ class Trainer: health_ckpts_record_dir=self.config.output_dir ) - if not self.config.ckpt_use_legacy_format: + if not self.config.use_legacy_format: self.config.load_checkpoint = get_checkpoint_path(self.config.load_checkpoint) else: self.config.load_checkpoint = self.get_load_checkpoint(self.config.load_checkpoint) @@ -614,7 +614,7 @@ class Trainer: self._check_config_rules() self._init_model(is_train=True) - if self.config.ckpt_use_legacy_format and self.config.resume_training: + if self.config.use_legacy_format and self.config.resume_training: if os.path.isfile(self.config.load_checkpoint) and \ isinstance(self.config.resume_training, str): logger.warning(f"`resume_training={self.config.resume_training}` is not valid " @@ -633,7 +633,7 @@ class Trainer: health_ckpts_record_dir=self.config.output_dir ) - if not self.config.ckpt_use_legacy_format: + if not self.config.use_legacy_format: self.config.load_checkpoint = get_checkpoint_path(self.config.load_checkpoint) else: self.config.load_checkpoint = self.get_load_checkpoint(self.config.load_checkpoint) @@ -1430,7 +1430,7 @@ class Trainer: logger.warning("The functionality of setting `resume_training` to a weight filename " "will be deprecated in future versions.") - if not self.config.ckpt_use_legacy_format and not isinstance(self.config.resume_training, bool): + if not self.config.use_legacy_format and not isinstance(self.config.resume_training, bool): raise ValueError("The resume_training must be a boolean value.") if isinstance(self.config.resume_training, str) and \ diff --git a/tests/st/test_ut/base_schema.json b/tests/st/test_ut/base_schema.json index 94ddfee58..326d24fca 100644 --- a/tests/st/test_ut/base_schema.json +++ b/tests/st/test_ut/base_schema.json @@ -1,6 +1,6 @@ { "mindformers.checkpoint.save_checkpoint": { - "signature": "(iteration: int, network: mindspore.nn.cell.Cell, optimizer: mindspore.nn.optim.optimizer.Optimizer = None, async_save_manager: mindformers.checkpoint.checkpoint.AsyncSaveManager = None, common_info: mindformers.checkpoint.checkpoint.CommonInfo = None, keep_max_num: int = 5, user_prefix: str = None, save_checkpoint_path: str = None, global_strategy_info: List[Dict] = None, remove_redundancy: bool = False)" + "signature": "(iteration: int, network: mindspore.nn.cell.Cell, optimizer: mindspore.nn.optim.optimizer.Optimizer = None, async_save_manager: mindformers.checkpoint.checkpoint.AsyncSaveManager = None, common_info: mindformers.checkpoint.checkpoint.CommonInfo = None, keep_max_num: int = 5, user_prefix: str = None, save_checkpoint_path: str = None, sharded_tensor_metas: list = None, remove_redundancy: bool = False)" }, "mindformers.core.AdamW": { "signature": "(params, learning_rate=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, use_fused=False, amsgrad=False, maximize=False, swap=False)" diff --git a/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py b/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py index 28aefeaba..41395f1d6 100644 --- a/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py +++ b/tests/st/test_ut/test_utils/test_tensorboard/test_tensorboard.py @@ -48,7 +48,7 @@ _CHECK_TEXT_MAPPING = { 'eval_epoch_interval', 
'eval_dataset', 'eval_dataset_task', 'lr_schedule', 'metric', 'model', 'moe_config', 'optimizer', 'parallel_config', 'parallel', 'recompute_config', 'remove_redundancy', 'runner_config', 'runner_wrapper', 'monitor_config', 'tensorboard', 'train_dataset_task', 'train_dataset', 'trainer', - 'swap_config', 'use_legacy', 'pretrained_model_dir', 'print_separate_loss', 'ckpt_use_legacy_format', + 'swap_config', 'use_legacy', 'pretrained_model_dir', 'print_separate_loss', 'use_legacy_format', 'balanced_load' } @@ -59,8 +59,7 @@ def generator_train(): batch_size = 1 vocab_size = 32000 input_ids = np.random.randint(low=0, high=vocab_size, size=(step_num * batch_size, seq_len,)).astype(np.int32) - for idx, _ in enumerate(input_ids): - yield input_ids[idx] + yield from input_ids class TestTensorBoard: """A test class for testing pipeline.""" -- Gitee
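# Appendix: a self-contained sketch of the greedy placement idea behind
# `distribute_shards` in fully_parallel.py -- each shard goes to the currently
# least-loaded rank among the ranks that actually hold a replica of it. The real
# function's exact ordering and tie-breaking may differ.
def distribute_shards_sketch(shard_coverage, shard_sizes, total_ranks):
    """Assign each shard to the least-loaded covering rank; returns {shard_id: rank}."""
    rank_loads = {rank: 0 for rank in range(total_ranks)}
    shard_assignment = {}
    # Place large shards first so they anchor the load balance.
    for shard_id in sorted(shard_coverage, key=lambda s: shard_sizes[s], reverse=True):
        selected_rank = min(shard_coverage[shard_id], key=lambda r: rank_loads[r])
        shard_assignment[shard_id] = selected_rank
        rank_loads[selected_rank] += shard_sizes[shard_id]
    return shard_assignment

# Example: "b" (size 8) anchors rank 0, "a" (size 4) then balances onto rank 1,
# and "c" (size 2) can only go to rank 1 -> {'b': 0, 'a': 1, 'c': 1}.
print(distribute_shards_sketch(
    {"a": [0, 1], "b": [0, 1], "c": [1]},
    {"a": 4, "b": 8, "c": 2},
    total_ranks=2,
))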