diff --git a/mindformers/checkpoint/checkpoint.py b/mindformers/checkpoint/checkpoint.py index 1f49ea6a1ad71692963308505366f69c48514d3f..71356b0bd00c619e11991ce3fc43be4ce0f33710 100644 --- a/mindformers/checkpoint/checkpoint.py +++ b/mindformers/checkpoint/checkpoint.py @@ -567,6 +567,7 @@ def load_tensor_by_offset( checkpoint_dir: str, src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], param_file_mappings: Dict[str, List[Dict[str, Any]]], + key_mapping: Dict[str, str], ) -> Dict[int, Parameter]: """ Loads specific tensor slices from checkpoint files based on offset information. @@ -581,16 +582,18 @@ def load_tensor_by_offset( checkpoint_dir: Directory containing the checkpoint files src_sharded_tensor_metas: Metadata for source sharded tensors param_file_mappings: Mapping of parameters to their storage files + key_mapping: Mapping of `original key` in checkpoint to `param key` in network. Returns: Dictionary mapping ranks to their corresponding loaded Parameter objects """ + def _get_storage_info_of_sharded_tensor( sharded_tensor: ShardedTensor, param_file_mappings: Dict[str, List[Dict[str, Any]]] ) -> List[Dict[str, Any]]: """Retrieves storage information for a specific sharded tensor.""" - param_key = str((sharded_tensor.key, sharded_tensor.global_offset)) + param_key = str((sharded_tensor.org_key, sharded_tensor.global_offset)) return param_file_mappings[param_key] def _get_storage_rank_dict_of_param( @@ -600,14 +603,14 @@ def load_tensor_by_offset( ) -> Dict[int, Tuple[str, Any]]: """Creates a dictionary mapping storage ranks to their file and dtype information.""" storage_rank_dict: Dict[int, Tuple[str, Any]] = {} + if param_name not in sharded_tensor_metas: + param_name = key_mapping[param_name] for sharded_tensor in sharded_tensor_metas[param_name]: storage_info_list = _get_storage_info_of_sharded_tensor(sharded_tensor, param_file_mappings) - for storage_info in storage_info_list: storage_rank = storage_info["storage_rank"] 
storage_rank_dict[storage_rank] = (storage_info["file_name"], sharded_tensor.dtype) - return storage_rank_dict # Get storage rank information for the parameter @@ -666,7 +669,7 @@ def categorize_params( dst_sharded_tensor_metas: Dict[str, List[ShardedTensor]], src_sharded_tensor_metas: Dict[str, List[ShardedTensor]], param_file_mappings: Dict[str, List[Dict[str, Any]]] -) -> Tuple[List[str], Dict[str, Dict[str, List[Any]]], Dict[str, List[Any]]]: +) -> Tuple[List[str], Dict[str, Dict[str, List[Any]]], Dict[str, Dict[str, List[Any]]], Dict[str, List[Any]]]: """ Categorizes parameters based on comparison of source and destination sharding strategies. @@ -692,7 +695,8 @@ def categorize_params( RuntimeError: If sharding strategies match but no corresponding parameter offset is found """ # Initialize categorization collections - special_params: List[str] = [] + not_mapping_params: List[str] = [] + need_concat_params: Dict[str, Dict[str, List[Any]]] = {} no_shard_params: Dict[str, Dict[str, List[Any]]] = {} no_shard_params_list: List[str] = [] online_shard_params: Dict[str, List[Any]] = {} @@ -703,9 +707,14 @@ def categorize_params( for param_name in dst_sharded_tensor_metas: # Handle parameters missing from source metadata if param_name not in src_sharded_tensor_metas: - special_params.append(param_name) + not_mapping_params.append(param_name) continue + # Get destination tensor strategy information + dst_sharded_tensor = dst_sharded_tensor_metas[param_name] + dst_global_shape, dst_axis_fragmentations, dst_global_offset = get_strategy_info_from_sharded_tensor( + dst_sharded_tensor) + src_sharded_tensor_list = src_sharded_tensor_metas[param_name] if not src_sharded_tensor_list: raise ValueError( @@ -713,10 +722,25 @@ def categorize_params( "Valid source metadata requires at least one ShardedTensor entry." 
) - # Get destination tensor strategy information - dst_sharded_tensor = dst_sharded_tensor_metas[param_name] - dst_global_shape, dst_axis_fragmentations, dst_global_offset = \ - get_strategy_info_from_sharded_tensor(dst_sharded_tensor) + # Get parameters info which need to concat + if param_name != src_sharded_tensor_list[0].key: + concat_infos = [] + reshard_infos = [] + for src_sharded_tensor in src_sharded_tensor_list: + param_key = str((src_sharded_tensor.org_key, src_sharded_tensor.global_offset)) + concat_infos.append( + { + 'sub_name': src_sharded_tensor.org_key, + 'file_name': param_file_mappings[param_key][0]["file_name"], + 'param_dtype': src_sharded_tensor.dtype, + } + ) + + if dst_axis_fragmentations != src_sharded_tensor_list[0].axis_fragmentations: + # `reshard_infos` contains `full_shape, from_layout, to_layout, to_rank_id` + reshard_infos = [dst_global_shape, None, dst_sharded_tensor.layout, rank_id] + need_concat_params[param_name] = (concat_infos, reshard_infos) + continue param_key: Optional[str] = None strategy_is_same = False @@ -738,7 +762,7 @@ def categorize_params( # Check if offsets match for direct mapping if src_global_offset == dst_global_offset: - param_key = str((param_name, src_global_offset)) + param_key = str((src_sharded_tensor.org_key, src_global_offset)) break # Found matching parameter # Validate strategy consistency @@ -755,26 +779,27 @@ def categorize_params( # Initialize entry if new file if file_name not in no_shard_params: no_shard_params[file_name] = { - "param_name_list": [param_name], + "param_name_list": [src_sharded_tensor.org_key], "param_dtype_list": [src_sharded_tensor.dtype], } else: # Add to existing file entry - no_shard_params[file_name]["param_name_list"].append(param_name) + no_shard_params[file_name]["param_name_list"].append(src_sharded_tensor.org_key) no_shard_params[file_name]["param_dtype_list"].append(src_sharded_tensor.dtype) - no_shard_params_list.append(param_name) + 
+ no_shard_params_list.append(src_sharded_tensor.org_key) else: # Parameters that need online resharding - online_shard_params[param_name] = [ + online_shard_params[src_sharded_tensor.org_key] = [ dst_global_shape, src_sharded_tensor.layout, dst_sharded_tensor.layout, rank_id ] - - logger.debug(f"Params needing transformation: {special_params}") + # Parameters to be processed for categorized logging + logger.debug(f"Params not mapping: {not_mapping_params}") + logger.debug(f"Params needing transformation: {need_concat_params}") logger.debug(f"Params no need reshard: {no_shard_params_list}") logger.debug(f"Params need reshard: {list(online_shard_params.keys())}") - return special_params, no_shard_params, online_shard_params + return not_mapping_params, need_concat_params, no_shard_params, online_shard_params def get_metadata_of_checkpoint(checkpoint_dir: str) -> tuple[dict, dict]: @@ -787,7 +812,8 @@ def get_metadata_of_checkpoint(checkpoint_dir: str) -> tuple[dict, dict]: by parsing the checkpoint files directly using load_metadata_from_checkpoint(). Args: - checkpoint_dir: Path to the directory containing the checkpoint files + checkpoint_dir: Path to the directory containing the checkpoint files. + network: The target core network (Cell) which has method `convert_name` to convert Hugging Face weight. Returns: A tuple containing two dictionaries: @@ -812,6 +838,63 @@ def get_metadata_of_checkpoint(checkpoint_dir: str) -> tuple[dict, dict]: return sharded_tensor_metas, param_file_mappings +def params_key_mapping( + sharded_tensor_metas: Dict[str, List[ShardedTensor]], + network: Cell +) -> tuple[dict, dict, Cell]: + """ + Map Hugging Face checkpoint keys to MindSpore Transformers parameter keys. + + Args: + sharded_tensor_metas: Metadata about sharded tensors. + network: The target core network (Cell) which has method `convert_name` to convert Hugging Face weight. + + Returns: + A tuple of (mapped sharded tensor metas, key mapping, core network).
+ """ + + # pylint: disable=W0212 + def get_core_network(network): + """Get the core network that has `convert_name` method.""" + if hasattr(network, 'convert_name'): + return network + if hasattr(network, '_backbone'): + return get_core_network(network._backbone) + if hasattr(network, 'network'): + return get_core_network(network.network) + raise NotImplementedError("Network has no function `convert_name`.") + + # Get the core network and check that the required conversion attributes are implemented + core_network = get_core_network(network) + if not hasattr(core_network, 'weight_mapping'): + raise NotImplementedError("The `weight_mapping` of network is not implemented.") + if not hasattr(core_network, 'convert_hf_weight'): + raise NotImplementedError("The `convert_hf_weight` method of network is not implemented.") + + # Each key of `mapped_sharded_tensor_metas` is a parameter key in the network, + # such as { qkv: [ShardedTensor, ShardedTensor, ShardedTensor], ... } + mapped_sharded_tensor_metas = {} + # `key_mapping` has the form {'weight_key': 'mapping_key'}, + # where the `mapping_key` may not have the same name as the parameter in the network; + # it could be an intermediate form, + # such as { 'q_proj': 'linear_q', 'k_proj': 'linear_k', 'v_proj': 'linear_v', ...
} + key_mapping = {} + + for param_name in sharded_tensor_metas: + param_name_converted = core_network.convert_name(param_name) + sharded_tensor_list = sharded_tensor_metas.get(param_name) + + for sharded_tensor in sharded_tensor_list: + sharded_tensor.key = param_name_converted + sharded_tensor.org_key = param_name + + key_mapping[param_name] = param_name_converted + param_name_converted_concat = core_network.convert_concat_name(param_name_converted) + mapped_sharded_tensor_metas.setdefault(param_name_converted_concat, []).extend(sharded_tensor_list) + + return mapped_sharded_tensor_metas, key_mapping, core_network + + def load_checkpoint( checkpoint: str, network: Cell, @@ -837,15 +920,7 @@ def load_checkpoint( (Other exceptions may be raised by dependent functions for checkpoint validation/loading) """ # Validate mandatory network parameter - if network is None: - raise ValueError("The 'network' cannot be None - a target network is required for loading.") - - if balanced_load: - raise ValueError("The balanced loading strategy is not supported yet. " - "`balanced_load` is a preset switch, please set `balanced_load` to False.") - - if not os.path.exists(checkpoint): - raise ValueError(f"Checkpoint does not exist: {checkpoint}") + check_the_param_for_load_ckpt(balanced_load, checkpoint, network) # Determine checkpoint directory path checkpoint_dir = get_checkpoint_path(checkpoint) @@ -854,21 +929,32 @@ def load_checkpoint( # Retrieve metadata from checkpoint files src_sharded_tensor_metas, param_file_mappings = get_metadata_of_checkpoint(checkpoint_dir) + # Mapping the weight keys, which is used to determine whether to load the Hugging Face weights. 
+ + try: + src_sharded_tensor_metas, key_mapping, core_network = params_key_mapping(src_sharded_tensor_metas, network) + # Validate the returned values + if not isinstance(src_sharded_tensor_metas, dict) or not isinstance(key_mapping, dict) or core_network is None: + raise ValueError("Mapping the params sharded metas failed.") + except NotImplementedError as e: + raise NotImplementedError( + f"Network '{type(network).__name__}' does not have the method to convert Hugging Face weights. " + "Please ensure the network or its backbone implements this method.") from e if not src_sharded_tensor_metas or not param_file_mappings: raise RuntimeError( - f"Failed to load valid metadata from checkpoint directory: {checkpoint_dir}. " + f"Failed to load valid metadata from checkpoint directory: `{checkpoint_dir}`. " "Metadata must include both sharded tensor information and parameter-file mappings." ) # Get current strategy metadata from network and optimizer logger.info(".........Get Current Strategy Metadata.........") - cur_rank_strategy_layout = get_current_strategy_metadata(network=network)[0] + cur_rank_strategy_layout = get_current_strategy_metadata(network=network) cur_rank_sharded_tensors: List[ShardedTensor] = [] if cur_rank_strategy_layout: # Convert strategy layout to required format - cur_rank_strategy_layout = [dict([item]) for item in cur_rank_strategy_layout.items()] + cur_rank_strategy_layout = [dict([item]) for item in cur_rank_strategy_layout[0].items()] # Define parameter filtering function def filter_fun(param_name: str) -> bool: @@ -888,26 +974,38 @@ def load_checkpoint( dst_sharded_tensor_metas = convert_sharded_tensor_list_to_dict(cur_rank_sharded_tensors) # Categorize parameters based on sharding strategies - _, no_shard_params, online_shard_params = categorize_params( - dst_sharded_tensor_metas, src_sharded_tensor_metas, param_file_mappings) + _, need_concat_params, no_shard_params, online_shard_params = categorize_params( + dst_sharded_tensor_metas, 
src_sharded_tensor_metas, param_file_mappings + ) - # Load parameters that don't require resharding + # Process Weight state_dict: Dict[str, Parameter] = {} + + # Concat parameters + concat_params(checkpoint_dir, core_network, key_mapping, need_concat_params, state_dict) + + # Load parameters that don't require resharding for file_name, param_info in no_shard_params.items(): param_name_list = param_info["param_name_list"] param_dtype_list = param_info["param_dtype_list"] - state_dict.update( - load_safetensor(os.path.join(checkpoint_dir, file_name), param_name_list, dtype=param_dtype_list) + no_reshard_state_dict = load_safetensor( + os.path.join(checkpoint_dir, file_name), param_name_list, dtype=param_dtype_list ) + state_dict.update({ + key_mapping[param_name]: value + for param_name, value in no_reshard_state_dict.items() + }) + # Load and reshard parameters that require online resharding for param_name, (full_shape, from_layout, to_layout, to_rank_id) in online_shard_params.items(): reshard_handler = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id) all_offset = reshard_handler.infer_all_tensor_offset() from_tensor_map = load_tensor_by_offset( - all_offset, param_name, checkpoint_dir, src_sharded_tensor_metas, param_file_mappings + all_offset, param_name, checkpoint_dir, src_sharded_tensor_metas, param_file_mappings, key_mapping ) target_weight = reshard_handler.get_real_tensor(from_tensor_map) + param_name = key_mapping[param_name] state_dict[param_name] = Parameter(target_weight, name=param_name, requires_grad=False) # Handle global_step for optimizer if needed @@ -915,8 +1013,7 @@ def load_checkpoint( # Initialize global_step with default or from common.json if not global_step: common_file = os.path.join(checkpoint_dir, 'common.json') - global_step = 0 if not os.path.exists(common_file) else \ - CommonInfo.load_common(common_file).global_step + global_step = 0 if not os.path.exists(common_file) else 
CommonInfo.load_common(common_file).global_step state_dict["global_step"] = Parameter( Tensor([global_step], mstype.int32), name="global_step", requires_grad=False @@ -926,6 +1023,55 @@ def load_checkpoint( load_parameters(network, state_dict, optimizer) +def concat_params(checkpoint_dir: str, core_network, key_mapping: dict, need_concat_params, state_dict: dict): + """Concat the need_concat_params dict in checkpoint.""" + for param_name, concat_info in need_concat_params.items(): + sharded_tensor_list, reshard_info = concat_info + org_weight_dict = {} + # Get all the params need to concat into `org_weight_dict`. + for sharded_tensor in sharded_tensor_list: + org_weight_dict.update( + load_safetensor( + checkpoint_path=os.path.join(checkpoint_dir, sharded_tensor['file_name']), + param_name=sharded_tensor['sub_name'], + dtype=sharded_tensor['param_dtype'] + ) + ) + # Mapping the weight key to MCore key into `concat_dict`. + concat_dict = { + key_mapping[k]: v + for k, v in org_weight_dict.items() + } + # Concat the weight. + concated_weight = core_network.convert_hf_weight(concat_dict) + + if reshard_info: + # Get the offset of the Tensor to reshard. + full_shape, from_layout, to_layout, to_rank_id = reshard_info + reshard_handler = ReshardHandler(param_name, full_shape, from_layout, to_layout, to_rank_id) + all_offset = reshard_handler.infer_all_tensor_offset() + # Get the slice to reshard the Tensor. + slices = tuple(slice(start, end) for start, end in all_offset[0]) + target_weight = concated_weight[param_name][slices] + # Update to `state_dict` to load into the network. 
+ state_dict[param_name] = Parameter(target_weight, name=param_name, requires_grad=False) + else: + state_dict[param_name] = Parameter(concated_weight[param_name], name=param_name, requires_grad=False) + + +def check_the_param_for_load_ckpt(balanced_load: bool, checkpoint: str, network: Cell): + """Check that the arguments passed to `load_checkpoint` are valid.""" + if network is None: + raise ValueError("The 'network' cannot be None - a target network is required for loading.") + + if balanced_load: + raise ValueError("The balanced loading strategy is not supported yet. " + "`balanced_load` is a preset switch, please set `balanced_load` to False.") + + if not os.path.exists(checkpoint): + raise ValueError(f"Checkpoint does not exist: {checkpoint}") + + def load_parameters( network: Cell, state_dict: Dict[str, Parameter], @@ -956,7 +1102,7 @@ def load_parameters( optimizer_param_names = set(optimizer.parameters_dict().keys()) if optimizer else set() for param_name in list(state_dict.keys()): if param_name not in network_param_names and param_name in optimizer_param_names \ - and param_name not in state_dict_opt: + and param_name not in state_dict_opt: state_dict_opt[param_name] = state_dict.pop(param_name) # Load parameters into network @@ -1007,6 +1153,27 @@ def get_checkpoint_path(checkpoint: str) -> str: if not os.path.isdir(checkpoint): raise ValueError(f"Checkpoint path is not a directory: {checkpoint}") + # When loading a Hugging Face checkpoint, check that every file listed in its index exists + hf_index_json = os.path.join(checkpoint, "model.safetensors.index.json") + if os.path.exists(hf_index_json): + with open(hf_index_json, 'r', encoding='utf-8') as f: + index_json = json.load(f) + if isinstance(index_json, dict): + weight_map = index_json['weight_map'] if 'weight_map' in index_json else index_json + else: + raise ValueError(f"Format of '{hf_index_json}' is illegal!") + + sf_file_list = set(weight_map.values()) + not_exist_file = [ + f + for f in sf_file_list + if not
os.path.isfile(os.path.join(checkpoint, f)) + ] + not_exist_file.sort() + if not_exist_file: + raise ValueError(f"The files '{not_exist_file}' do not exist in `{checkpoint}`.") + return checkpoint + tracker_filename = get_checkpoint_tracker_filename(checkpoint) if os.path.exists(tracker_filename): iteration = get_latest_iteration_from_tracker(checkpoint) diff --git a/mindformers/checkpoint/sharded_tensor.py b/mindformers/checkpoint/sharded_tensor.py index 9677149f97265938fda14133bcedd4c9aa588c66..041c4c01881efcba11886e35282634b272d7f596 100644 --- a/mindformers/checkpoint/sharded_tensor.py +++ b/mindformers/checkpoint/sharded_tensor.py @@ -18,12 +18,13 @@ from collections import defaultdict from dataclasses import dataclass from typing import List, Dict, Optional, Tuple, Union, Callable -from mindformers.tools.logger import logger - import mindspore as ms from mindspore.nn import Cell from mindspore.parallel.shard import _DistributedTensorInfo +from mindformers.tools.logger import logger + + ReplicaId = Union[int, Tuple[int, ...]] @@ -39,6 +40,12 @@ class ShardedTensor: key: str """Unique identifier of a global tensor.""" + org_key: str + """ + Record the original weight key name. + Mostly used in load Hugging Face weight with online resharding. 
+ """ + dtype: ms.dtype """Tensor dtype.""" @@ -80,8 +87,9 @@ def build_sharded_tensor( ) -> ShardedTensor: """Creates and returns a ShardedTensor instance with the specified parameters.""" return ShardedTensor( - key=param_name, dtype=param_dtype, local_shape=tuple(local_shape), global_shape=tuple(global_shape), - global_offset=tuple(global_offset), axis_fragmentations=tuple(axis_fragmentations), replica_id=replica_id, + key=param_name, org_key=param_name, dtype=param_dtype, local_shape=tuple(local_shape), + global_shape=tuple(global_shape), global_offset=tuple(global_offset), + axis_fragmentations=tuple(axis_fragmentations), replica_id=replica_id, allow_shape_mismatch=allow_shape_mismatch, allow_to_save=allow_to_save, layout=layout ) @@ -342,7 +350,7 @@ def get_sharded_tensor_list_from_strategy_metadata(param_infos: List[Dict], cur_ if not param_infos: return None - cur_rank_sharded_tensor_list = list() + cur_rank_sharded_tensor_list = [] cur_param_name_list = get_param_name_from_layout(param_infos) cur_value_type_list = get_value_type_from_layout(param_infos) @@ -399,7 +407,7 @@ def get_sharded_tensor_list_from_cell( Returns: List of ShardedTensor objects with metadata from network and optimizer parameters """ - logger.info(f".........Get Current Strategy Metadata from Cell.........") + logger.info(".........Get Current Strategy Metadata from Cell.........") cur_rank_sharded_tensor_list: List[ShardedTensor] = [] def _get_sharded_tensors_from_cell( @@ -418,7 +426,7 @@ def get_sharded_tensor_list_from_cell( Returns: List of ShardedTensor objects for the cell's parameters """ - sharded_tensor_list = list() + sharded_tensor_list = [] for param in cell.get_parameters(): param_name = param.name diff --git a/mindformers/models/qwen3/modeling_qwen3_train.py b/mindformers/models/qwen3/modeling_qwen3_train.py index 98d83b0925a7746fb63e1f526fb3b702c0281b0d..1257ad50f277ea79884a13d6323e762be01ef0e3 100644 --- a/mindformers/models/qwen3/modeling_qwen3_train.py +++ 
b/mindformers/models/qwen3/modeling_qwen3_train.py @@ -206,3 +206,62 @@ class TrainingQwen3ForCausalLM(TrainModelMixin, Qwen3PreTrainedModel): ms_name_map_dict.update({w_gate_hidden_key: w_gate_hidden_value}) return ms_name_map_dict + + def convert_hf_weight(self, need_concat_dict): + """ + Convert Hugging Face weight to MindSpore Transformers weight. + + For Qwen3 Model, it is necessary to concatenate `q_proj/k_proj/v_proj` weight to `linear_qkv`, + and concatenate `gate_proj/up_proj` weight to `linear_fc1`. + + Args: + need_concat_dict: The weight dict contains each group weight to concatenate. + + Returns: + A dict contains the concatenated weight. + """ + transformer_config = self.get_gpt_transformer_config() + + use_contiguous_weight_layout_attention = transformer_config.use_contiguous_weight_layout_attention + use_interleaved_weight_layout_mlp = transformer_config.use_interleaved_weight_layout_mlp + + ms_weight_dict = {} + + head_dim = transformer_config.kv_channels + n_kv_heads = transformer_config.num_query_groups + num_attention_heads = transformer_config.num_attention_heads + + ffn_hidden_size = transformer_config.ffn_hidden_size + + for name, value in need_concat_dict.items(): + part = name.split('.') + # Get Q/K/V Keys + if part[-2] == 'linear_q': + k_name = name.replace('linear_q', 'linear_k') + v_name = name.replace('linear_q', 'linear_v') + k_value = need_concat_dict.get(k_name) + v_value = need_concat_dict.get(v_name) + if use_contiguous_weight_layout_attention: + ms_weight_dict.update( + self.concat_qkv_contiguous(value, k_value, v_value, name) + ) + else: + ms_weight_dict.update( + self.concat_qkv_interleaved( + value, k_value, v_value, name, head_dim, n_kv_heads, num_attention_heads + ) + ) + # Get gate/up Keys + elif part[-2] == 'gating': + up_name = name.replace('gating', 'hidden') + up_value = need_concat_dict.get(up_name) + if use_interleaved_weight_layout_mlp: + ms_weight_dict.update( + self.concat_linear_fc1_interleaved(value, up_value, 
name, ffn_hidden_size) + ) + else: + ms_weight_dict.update( + self.concat_linear_fc1_contiguous(value, up_value, name) + ) + + return ms_weight_dict diff --git a/mindformers/parallel_core/utils/model_mixin.py b/mindformers/parallel_core/utils/model_mixin.py index 7a0bff6982a09fc978a585ce784ddc85623aca18..08861a4224edb4fa7a04943f1c5b2e54d4c3538a 100644 --- a/mindformers/parallel_core/utils/model_mixin.py +++ b/mindformers/parallel_core/utils/model_mixin.py @@ -28,9 +28,35 @@ class ModelMixin: A few utilities for `mindspore.nn.Cell`, to be used as a mixin. """ + concat_mapping = [ + ('.linear_q.', '.linear_qkv.'), + ('.linear_k.', '.linear_qkv.'), + ('.linear_v.', '.linear_qkv.'), + ('.mlp.gating.', '.mlp.linear_fc1.'), + ('.mlp.hidden.', '.mlp.linear_fc1.'), + ] + def __init__(self): self.transformer_config = None + def convert_concat_name(self, weight_name): + r""" + convert HuggingFace weight name to MindFormers weight name. + + Args: + weight_name: huggingface weight names. + + Returns: + weight_name: converted weight names. + + """ + if is_legacy_model(): + raise RuntimeError(f"{self.__class__.__name__} does not implemented convert_name method.") + for split_name, concat_name in self.concat_mapping: + if split_name in weight_name: + weight_name = weight_name.replace(split_name, concat_name) + return weight_name + def convert_name(self, weight_name): r""" convert HuggingFace weight name to MindFormers weight name. @@ -129,6 +155,73 @@ class TrainModelMixin: is_train_model = True + def concat_qkv_contiguous(self, q_value, k_value, v_value, q_name): + """ + Concat the Q/K/V weight in contiguous format: + [Q_weights, K_weights, V_weights]. 
+ """ + qkv_name = q_name.replace('linear_q', 'linear_qkv') + qkv_value = np.concatenate((q_value, k_value, v_value), 0) + + # return converted qkv weight + return {qkv_name: qkv_value} + + def concat_qkv_interleaved(self, q_value, k_value, v_value, q_name, + head_dim, n_kv_heads, num_attention_heads): + """ + Concat the Q/K/V weight in interleaved format: + [Q_head0, K_head0, V_head0, Q_head1, ...]. + """ + n_rep = num_attention_heads // n_kv_heads + + # Start to concat qkv weight + qkv_name = q_name.replace('linear_q', 'linear_qkv') + + # Reshape the q/k/v weight to 3d + q_reshape = q_value.reshape(n_kv_heads, n_rep * head_dim, -1) + k_reshape = k_value.reshape(n_kv_heads, head_dim, -1) + v_reshape = v_value.reshape(n_kv_heads, head_dim, -1) + + # Then concat them with column (axis 1) + concat_qkv_weight = np.concatenate((q_reshape, k_reshape, v_reshape), axis=1) + + # Reshape the concated qkv weight to 2d + qkv_value = concat_qkv_weight.reshape((n_rep + 2) * head_dim * n_kv_heads, -1) + + # return converted qkv weight + return {qkv_name: qkv_value} + + def concat_linear_fc1_contiguous(self, gate_value, up_value, gate_name): + """ + Concat the gate/up weight in contiguous format: + [Gate_weights, Hidden_weights]. + """ + linear_fc1_key = gate_name.replace('gating', 'linear_fc1') + linear_fc1_value = np.concatenate((gate_value, up_value), 0) + + # return converted ffn weight + return {linear_fc1_key: linear_fc1_value} + + def concat_linear_fc1_interleaved(self, gate_value, up_value, gate_name, ffn_hidden_size): + """ + Concat the gate/up weight in interleaved format: + [Gate_weights[0], Hidden_weights[0], Gate_weights[1], Hidden_weights[1], ...]. 
+ """ + linear_fc1_key = gate_name.replace('gating', 'linear_fc1') + + # Reshape gate/up to 3d + gate_reshape = gate_value.reshape(ffn_hidden_size, 1, -1) + hidden_reshape = up_value.reshape(ffn_hidden_size, 1, -1) + + # Concat gate and up + linear_fc1_value = np.concatenate((gate_reshape, hidden_reshape), axis=1) + + # Reshape the concated linear_fc1 weight to 2d + linear_fc1_value = linear_fc1_value.reshape(ffn_hidden_size * 2, -1) + + # return converted ffn weight + return {linear_fc1_key: linear_fc1_value} + def concat_qkv_weight_megatron(self, wq_keys, wk_keys, wv_keys, qkv_weight_dict, condition, ms_weight_dict, head_dim, n_kv_heads, num_attention_heads): """ diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 4aff24391387cd325f859c0d973a363db3ea5a18..eae453e4a2e598dcfb36491dca6ce168b62ba0f0 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -1222,7 +1222,7 @@ class BaseTrainer: for callback in self.config.callbacks: if "type" in callback and callback["type"] == "CheckpointMonitor": save_checkpoint_with_legacy_format = callback.get("use_legacy_format", True) - if not save_checkpoint_with_legacy_format: + if not save_checkpoint_with_legacy_format or not config.ckpt_use_legacy_format: try: from mindspore.parallel.strategy import enable_save_strategy_online enable_save_strategy_online() @@ -1268,7 +1268,8 @@ class BaseTrainer: is_mtp_model = network.is_mtp_model() transformer_config = network.get_gpt_transformer_config() - config.load_checkpoint = get_load_path_after_hf_convert(config, network) + if config.ckpt_use_legacy_format: + config.load_checkpoint = get_load_path_after_hf_convert(config, network) self._check_training_network_no_use_past(network) eval_network = None