diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index 348846945acd1143639997ff7f62c891d91ec05f..ac7f6f5937fa11007936ceb369a442aeeea94feb 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -1070,7 +1070,7 @@ class TrainingStateMonitor(Callback): if 'log' in self.max_attention_logit_format: self._output(tag, v.tolist(), step, ['log']) if 'tensorboard' in self.max_attention_logit_format: - tp_id = get_rank() // self.tensor_model_parallel_size + tp_id = get_rank() // self.tensor_model_parallel_size head_start = tp_id * len(v) data = {f"head_{head_start+i}": max_attention_logit for i, max_attention_logit in enumerate(v)} self._output(tag, data, step, ['tensorboard']) @@ -2531,6 +2531,7 @@ class MaxLogitsMonitor(Callback): network = network._backbone self._reset_max_attention_logit(network) + @MindFormerRegister.register(MindFormerModuleType.CALLBACK) class TopkBiasBalanceCallback(Callback): """ diff --git a/mindformers/core/config_args.py b/mindformers/core/config_args.py index 84b6d88463768461cb54c3de466ba1933492600d..82652d5570bdabafea8ba1e8664f5adb252e9ab7 100644 --- a/mindformers/core/config_args.py +++ b/mindformers/core/config_args.py @@ -506,8 +506,7 @@ class MFContextConfig(BaseArgsConfig): 'profile_level', 'profile_start_step', 'profile_stop_step', - 'pretrained_model_dir', - 'monitor_max_attention_logit' + 'pretrained_model_dir' ] def __init__( diff --git a/mindformers/core/optim/__init__.py b/mindformers/core/optim/__init__.py index 26fe52ff207cb4ce00959ee6743f0bc7de38fce6..74b46a495b3cd9696545eb9ac38700c840e04e0f 100644 --- a/mindformers/core/optim/__init__.py +++ b/mindformers/core/optim/__init__.py @@ -21,8 +21,9 @@ from .adamw import AdamW as BasicAdamW from .fused_adamw import FusedAdamW from .pma_adamw import PmaAdamW as BasicPmaAdamW from .fused_pma_adamw import FusedPmaAdamW +from .muon import Muon -__all__ = ['AdamW', 'PmaAdamW'] +__all__ = ['AdamW', 'PmaAdamW', 'Muon'] @MindFormerRegister.register(MindFormerModuleType.OPTIMIZER) diff --git a/mindformers/core/optim/muon.py b/mindformers/core/optim/muon.py new file mode 100644 index 0000000000000000000000000000000000000000..b038d1eaf8c77988615b706c10f4b30ddbd3253f --- /dev/null +++ b/mindformers/core/optim/muon.py @@ -0,0 +1,527 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Muon API""" + +from __future__ import absolute_import + +import hashlib + +import numpy as np +from mindspore.common import dtype as mstype +from mindspore.ops import functional as F, composite as C, operations as P +from mindspore.common.api import jit +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter, ParameterTuple +from mindspore.communication.management import create_group, get_rank +from mindspore.ops.auto_generate import Chunk +from mindspore import get_auto_parallel_context + +from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.core.context import is_legacy_model +from mindformers.tools.logger import logger + +_muon_opt = C.MultitypeFuncGraph("muon_opt") + + +def _perform_allgather_op(ns_inputs_item, op, tp, tp_dim, op_group, tp_group, param_name): + """Perform AllGather operations based on op and tp settings.""" + if "mlp.experts.weight" not in param_name: + # all gather op_shard + if op > 1: + ns_inputs_item = P.AllGather(group=op_group)(ns_inputs_item) + + # all gather tp_shard + if tp > 1: + if tp_dim == 0: + ns_inputs_item = P.AllGather(group=tp_group)(ns_inputs_item) + elif tp_dim == 1: + ns_inputs_item = P.AllGather(group=tp_group)(ns_inputs_item.T) + ns_inputs_item = ns_inputs_item.T + return ns_inputs_item + + +def zeropower_via_newtonschulz5_2d(x, dim_a, dim_b): + """Apply Newton-Schulz iteration for 2D tensors.""" + a, b, c = (3.4445, -4.7750, 2.0315) + + if dim_a > dim_b: + x = x.T + # Ensure spectral norm is at most 1 + x = x / (x.norm() + 1e-7) + # Perform the NS iterations + for _ in range(5): + a_mat = x @ x.T + b_mat = b * a_mat + c * a_mat @ a_mat + x = a * x + b_mat @ x + if dim_a > dim_b: + x = x.T + return x + + +def zeropower_via_newtonschulz5_3d(x, dim_a, dim_b): + """Apply Newton-Schulz iteration for 3D tensors.""" + a, b, c = (3.4445, -4.7750, 2.0315) + + if dim_a > dim_b: + x = P.Transpose()(x, (0, 2, 1)) + # Ensure spectral norm is at most 1 + x = x / P.ExpandDims()(P.ExpandDims()((x.norm(dim=(1, 2)) + 1e-7), 1), 1) + # Perform the NS iterations + for _ in range(5): + a_mat = P.BatchMatMul(transpose_b=True)(x, x) + b_mat = b * a_mat + c * P.BatchMatMul()(a_mat, a_mat) + x = a * x + P.BatchMatMul()(b_mat, x) + if dim_a > dim_b: + x = P.Transpose()(x, (0, 2, 1)) + return x + + +def _slice_tensor_to_shards(x, tp, tp_dim, op, rank_id, op_group, tp_group): + """Slice tensor to tp_shard and op_shard.""" + # slice X to tp_shard and slice X to op_shard + if tp > 1: + if tp_dim >= 0: + chunk_id = rank_id % tp + x = Chunk()(x, tp, tp_dim)[chunk_id] + + if op > 1: + if op_group == tp_group: + chunk_id = rank_id % tp + else: + chunk_id = rank_id // tp % op + x = Chunk()(x, op)[chunk_id] + return x + + +def _apply_muon_update( + gradient, muon_m, momentum, use_nesterov, param, lr, weight_decay, + matched_adamw_rms, muon_split_fn, muon_merge_fn, param_name, + op, tp, tp_dim, rank_id, op_group, tp_group): + """Apply Muon optimizer update.""" + op_sqrt = P.Sqrt() + op_cast = P.Cast() + op_reshape = P.Reshape() + op_shape = P.Shape() + + m_fp32 = op_cast(muon_m, mstype.float32) + gradient_fp32 = op_cast(gradient, mstype.float32) + next_m = m_fp32 * momentum + gradient_fp32 + + if use_nesterov: + gradient_fp32 = gradient_fp32 + next_m * momentum + else: + gradient_fp32 = next_m + + ns_inputs = op_cast(gradient_fp32, mstype.bfloat16) + ns_inputs_list = 
muon_split_fn(param_name, ns_inputs) + x_list = [] + + dim_a, dim_b = None, None + for ns_inputs_item in ns_inputs_list: + dim_a, dim_b = op_shape(ns_inputs_item)[-2:] + + if len(op_shape(ns_inputs_item)) == 2: + ns_inputs_item = _perform_allgather_op( + ns_inputs_item, op, tp, tp_dim, op_group, tp_group, param_name) + x = zeropower_via_newtonschulz5_2d(ns_inputs_item, dim_a, dim_b) + x = _slice_tensor_to_shards(x, tp, tp_dim, op, rank_id, op_group, tp_group) + else: + x = zeropower_via_newtonschulz5_3d(ns_inputs_item, dim_a, dim_b) + + x_list.append(x) + + x_ret = muon_merge_fn(param_name, x_list) + param_fp32 = op_cast(param, mstype.float32) + param_fp32 = param_fp32 * (1 - lr * weight_decay) + + adjusted_ratio = op_sqrt(op_cast(max(dim_a, dim_b), mstype.float32)) * matched_adamw_rms + adjusted_lr = lr * adjusted_ratio + update_with_lr = adjusted_lr * x_ret + next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32)) + next_param = F.depend(next_param, F.assign(param, op_cast(next_param, F.dtype(param)))) + next_param = F.depend(next_param, F.assign(muon_m, op_cast(next_m, F.dtype(muon_m)))) + return op_cast(next_param, F.dtype(param)) + + +def _apply_adamw_update(param, exp_avg, exp_avg_sq, gradient, beta1, beta2, step, eps, lr, weight_decay): + """Apply AdamW optimizer update.""" + op_mul = P.Mul() + op_pow = P.Pow() + op_sqrt = P.Sqrt() + op_cast = P.Cast() + addcmul = P.Addcmul() + + param_fp32 = op_cast(param, mstype.float32) + next_param = op_mul(param_fp32, 1 - lr * weight_decay) + gradient_fp32 = op_cast(gradient, mstype.float32) + + next_param = F.depend( + next_param, + F.assign( + exp_avg, + op_mul(exp_avg, beta1) + + op_mul(gradient_fp32, op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1), + ), + ) + next_param = F.depend( + next_param, + F.assign( + exp_avg_sq, + addcmul( + op_mul(exp_avg_sq, beta2), + gradient_fp32, + gradient_fp32, + op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, + ), + ), + ) + + bias_correction1 = 1 - op_pow(op_cast(beta1, mstype.float32), step) + bias_correction2 = 1 - op_pow(op_cast(beta2, mstype.float32), step) + step_size = lr / bias_correction1 + denom = op_sqrt(exp_avg_sq / bias_correction2) + eps + return_param = next_param - op_mul(exp_avg / denom, step_size) + F.assign(param, op_cast(return_param, F.dtype(param))) + return op_cast(return_param, F.dtype(param)) + + +@_muon_opt.register( + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", + "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Number", "Number", "Bool", + "Bool", "Bool", "String", "String", "String", "Function", "Function") +def _update_run_op( + momentum, matched_adamw_rms, beta1, beta2, step, eps, lr, weight_decay, rank_id, + param, exp_avg, exp_avg_sq, gradient, muon_m, tp, op, tp_dim, use_muon, + use_nesterov, optim_filter, op_group, tp_group, param_name, muon_split_fn, muon_merge_fn): + """ + Update parameters. + + Args: + beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). + beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). + eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. + lr (Tensor): Learning rate. + weight_decay (numbers.Number): Weight decay. Should be equal to or greater than 0. + param (Tensor): Parameters. + m (Tensor): m value of parameters. + v (Tensor): v value of parameters. + gradient (Tensor): Gradient of parameters. 
+ decay_flag (bool): Applies weight decay or not. + optim_filter (bool): Applies parameter update or not. + + Returns: + Tensor, the new value of v after updating. + """ + op_cast = P.Cast() + if "max_logits_val" in param_name: + return op_cast(gradient, F.dtype(param)) + + if not optim_filter: + return gradient + + if use_muon: + return _apply_muon_update( + gradient, muon_m, momentum, use_nesterov, param, lr, weight_decay, + matched_adamw_rms, muon_split_fn, muon_merge_fn, param_name, + op, tp, tp_dim, rank_id, op_group, tp_group) + + return _apply_adamw_update(param, exp_avg, exp_avg_sq, gradient, beta1, beta2, step, eps, lr, weight_decay) + + +@MindFormerRegister.register(MindFormerModuleType.OPTIMIZER) +class Muon(Optimizer): + """ + Muon optimizer implementation. + + Args: + params: model parameters to optimize. + learning_rate (float): Learning rate. Default: ``2e-2``. + weight_decay (float): Weight decay factor. Default: ``0.1``. + matched_adamw_rms (float): RMS matching parameter for AdamW. Default: ``0.2``. + momentum (float): Momentum factor. Default: ``0.95``. + nesterov (bool): Whether to use Nesterov momentum. Default: ``True``. + ns_steps (int): Number of Newton-Schulz steps. Default: ``5``. + adamw_betas (tuple): Beta parameters for AdamW. Default: ``(0.95, 0.95)``. + adamw_eps (float): Epsilon for AdamW. Default: ``1e-8``. + micro_batch_num (int): Number of micro batches. Default: ``1``. + qk_clip_threshold (float): QK clip threshold. Default: ``4``. + model: The model model. Default: ``None``. + """ + + def __init__( + self, + params, + learning_rate=2e-2, + weight_decay=0.1, + matched_adamw_rms=0.2, + momentum=0.95, + nesterov=True, + ns_steps=5, + adamw_betas=(0.95, 0.95), + adamw_eps=1e-8, + micro_batch_num=1, + qk_clip_threshold=4, + model=None, + ): + super().__init__(learning_rate, params, weight_decay) + + self._verify_model(model) + + # Initialize basic parameters + self._initialize_basic_params(adamw_betas, adamw_eps, momentum, matched_adamw_rms, nesterov) + + # Initialize model configuration + self._initialize_network_config(model) + + # Initialize parameter layers + self._initialize_param_layers(model) + + # Initialize QK-clip parameters + self.ones = Tensor([1.0], mstype.float32) + self.rank_id = get_rank() + self.rank_ids = tuple(self.rank_id for _ in self._parameters) + self.logit_threshold = Tensor([qk_clip_threshold * micro_batch_num], dtype=mstype.float32) + + # Initialize Muon momentum + self._initialize_muon_moments(model) + + # Initialize tensor parallel dimensions + self._initialize_tp_dims(model) + + # Initialize AdamW moments + self._initialize_adamw_moments(model) + + # Initialize parallel configuration + self._initialize_parallel_config(model) + + # Initialize communication groups + self._initialize_communication_groups() + + # Initialize optimizer parallel groups + self._initialize_op_groups(model) + + # Store model for QK-clip + self.model = model + self.ns_steps = ns_steps + + def _verify_model(self, model): + """Verify if the model is compatible with Muon optimizer.""" + if model is None: + raise ValueError("Model must be provided for Muon optimizer.") + + if is_legacy_model(): + raise ValueError("Muon does not support Legacy Model.") + + config = model.get_gpt_transformer_config() + + if not config.multi_latent_attention: + raise ValueError("Current Muon implementation only supports models with Multi-Latent Attention enabled.") + + def _initialize_basic_params(self, adamw_betas, adamw_eps, momentum, matched_adamw_rms, nesterov): + 
"""Initialize basic optimizer parameters.""" + self.beta1 = Tensor(np.array([adamw_betas[0]]).astype(np.float32)) + self.beta2 = Tensor(np.array([adamw_betas[1]]).astype(np.float32)) + self.eps = Tensor(np.array([adamw_eps]).astype(np.float32)) + self.muon_momentum = Tensor(np.array([momentum]).astype(np.float32)) + self.matched_adamw_rms = Tensor(np.array([matched_adamw_rms]).astype(np.float32)) + self.use_nesterov = tuple(nesterov for _ in self._parameters) + self.param_name_tuple = tuple(p.name for p in self._parameters) + + def _initialize_network_config(self, model): + """Initialize Model configuration and split/merge functions.""" + + self.muon_split_fn, self.muon_merge_fn = model.make_model_muon_fns() + self.muon_split_fns = tuple(self.muon_split_fn for _ in self._parameters) + self.muon_merge_fns = tuple(self.muon_merge_fn for _ in self._parameters) + + def _initialize_param_layers(self, model): + """Initialize parameter layer indices.""" + self.param_layer = model.get_param_layer_indices(self._parameters) + + def _initialize_muon_moments(self, model): + """Initialize Muon momentum parameters.""" + muon_filter = model.get_muon_filter() + + self.muon_m = [] + self.param_idx_in_opt = {} + for idx, param in enumerate(self._parameters): + self.param_idx_in_opt[param.name] = idx + + for param in self._parameters: + if muon_filter(param): + x1 = param.clone("zeros") + x1.name = "muon_m" + "." + x1.name + self.muon_m.append(x1) + logger.info(f"Muon apply: {param}") + else: + self.muon_m.append(Parameter(Tensor(np.array([0]).astype(np.float32)), name="muon_m." + param.name)) + self.muon_m = ParameterTuple(self.muon_m) + self.use_muon = tuple(muon_filter(param) for param in self._parameters) + + def _initialize_tp_dims(self, model): + """Initialize tensor parallel dimensions.""" + self.tp_dims = model.get_tp_dims(self._parameters) + + def _initialize_adamw_moments(self, model): + """Initialize AdamW momentum parameters.""" + muon_filter = model.get_muon_filter() + + self.moments1 = [] + self.moments2 = [] + for param in self._parameters: + if not muon_filter(param): + x1 = param.clone("zeros") + x1.name = "adam_m" + "." + x1.name + self.moments1.append(x1) + x2 = param.clone("zeros") + x2.name = "adam_v" + "." + x2.name + self.moments2.append(x2) + logger.info(f"Adam apply: {param}") + else: + self.moments1.append(Parameter(Tensor(np.array([0]).astype(np.float32)), name="adam_m." + param.name)) + self.moments2.append(Parameter(Tensor(np.array([0]).astype(np.float32)), name="adam_v." 
+ param.name)) + self.moments1 = ParameterTuple(self.moments1) + self.moments2 = ParameterTuple(self.moments2) + + def _initialize_parallel_config(self, model): + """Initialize parallel configuration.""" + self.tp = model.get_gpt_transformer_config().tensor_model_parallel_size + self.tps = tuple(self.tp for _ in self._parameters) + logger.info(f"Muon tp group size is: {self.tp}") + + if not get_auto_parallel_context('enable_parallel_optimizer'): + self.op = 1 + else: + self.op = get_auto_parallel_context('optimizer_weight_shard_size') + if self.op == -1: + raise ValueError("Must set parallel.parallel_optimizer_config.optimizer_weight_shard_size when using Muon") + logger.info(f"Muon op group size is: {self.op}") + + def _initialize_communication_groups(self): + """Initialize communication groups for parallel training.""" + self.tp_group = self._get_tp_group_name() + self.op_group = self._get_op_group_name() + self.tp_groups = tuple(self.tp_group for _ in self._parameters) + + def _initialize_op_groups(self, model): + """Initialize optimizer parallel groups for parameters.""" + self.ops, self.op_groups = model.get_op_groups_info( + self._parameters, self.op, self.tp_group, self.op_group + ) + + def _create_communication_group(self, rank_list, group_type): + """ + Create a communication group with a hashed name. + + Args: + rank_list: List of ranks in the communication group + group_type: Type of group for logging (e.g., "op", "tp") + + Returns: + str: The created group name + """ + rank_list_str = "-".join([str(i) for i in rank_list]) + logger.info(f"Muon {group_type} group list is: {rank_list_str}") + hashed = hashlib.md5(rank_list_str.encode()).hexdigest()[:48] + group_name = str(hashed) + create_group(group_name, rank_list) + return group_name + + def _get_op_group_name(self): + """ + Generates a unique group name for optimizer parallel communication group. + + Returns: + str: The optimizer parallel group name + """ + dp_range = self.tp + op_range = self.tp * self.op + rank_start = self.rank_id % dp_range + self.rank_id // op_range * op_range + rand_end = rank_start + op_range + rank_list = list(range(rank_start, rand_end, dp_range)) + return self._create_communication_group(rank_list, "op") + + def _get_tp_group_name(self): + """ + Generates a unique group name for tensor parallel communication group. + + Returns: + str: The tensor parallel group name + """ + rank_start = self.rank_id // self.tp * self.tp + rand_end = self.rank_id // self.tp * self.tp + self.tp + rank_list = list(range(rank_start, rand_end)) + return self._create_communication_group(rank_list, "tp") + + @jit(backend="ms_backend") + def construct(self, gradients): + """Construct method for optimizer. + + Args: + gradients: Gradients for optimization. + + Returns: + Updated gradients after optimization. 
+ """ + gradients = self.flatten_gradients(gradients) + weight_decay = self.get_weight_decay() + lr = self.get_lr() + self.assignadd(self.global_step, self.global_step_increase_tensor) + optim_result = self.hyper_map( + F.partial( + _muon_opt, + self.muon_momentum, + self.matched_adamw_rms, + self.beta1, + self.beta2, + self.global_step, + self.eps, + lr, + ), + weight_decay, + self.rank_ids, + self._parameters, + self.moments1, + self.moments2, + gradients, + self.muon_m, + self.tps, + self.ops, + self.tp_dims, + self.use_muon, + self.use_nesterov, + self.optim_filter, + self.op_groups, + self.tp_groups, + self.param_name_tuple, + self.muon_split_fns, + self.muon_merge_fns, + ) + + updates = self.model.apply_qk_clip_scaling( + self._parameters, + self.param_name_tuple, + self.param_layer, + self.logit_threshold, + self.muon_split_fn, + self.muon_merge_fn, + ) + + # Apply the weight updates + for param_idx, weights in updates: + optim_result = F.depend(optim_result, F.assign(self._parameters[param_idx], weights)) + + return optim_result diff --git a/mindformers/core/optim/muon_utils.py b/mindformers/core/optim/muon_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e23dfbf7faf7ee9c7714bdbab5ec3c88d4a28ede --- /dev/null +++ b/mindformers/core/optim/muon_utils.py @@ -0,0 +1,322 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Muon utils""" + +import math +from fnmatch import fnmatch +from mindspore import ops as P +from mindspore.ops.operations import Morph +from mindspore import nn + + +class BlockSplitReshape(nn.Cell): + """ + Reshape tensor by splitting its last dimension into blocks. + + This operation takes a tensor and splits its last dimension into equal-sized blocks, + adding a new dimension for the block index. + + Args: + block: Block size for splitting the last dimension. + """ + + def __init__( + self, + block + ): + super().__init__() + self.block = block + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + *prefix, dim = args[0] + t = prefix + [dim // self.block, self.block] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +class TensorReshapeTo3D(nn.Cell): + """ + Reshape tensor to 3D with specified middle and last dimensions. + + This operation reshapes a tensor to a 3-dimensional tensor where the first dimension + is automatically calculated from the total size, and the last two dimensions are fixed. + + Args: + dim1: The second dimension (middle dimension) of the output 3D tensor. + dim2: The third dimension (last dimension) of the output 3D tensor. 
+ """ + + def __init__( + self, + dim1, + dim2, + ): + super().__init__() + self.dim1 = dim1 + self.dim2 = dim2 + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + tensor_shape = args[0] + total = math.prod(tensor_shape) + t = [total // (self.dim1 * self.dim2), self.dim1, self.dim2] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +class PrefixDimensionReshape(nn.Cell): + """ + Reshape tensor with fixed prefix dimensions and calculated last dimension. + + This operation reshapes a tensor by specifying the leading (prefix) dimensions, + while the last dimension is automatically calculated from the total size. + + Args: + *prefix: Variable number of prefix dimensions for the output tensor shape. + """ + + def __init__( + self, + *prefix + ): + self.prefix = list(prefix) + super().__init__() + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + tensor_shape = args[0] + total = math.prod(tensor_shape) + prefix_total = math.prod(self.prefix) + t = self.prefix + [total // prefix_total] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +class TensorReshapeTo2D(nn.Cell): + """ + Reshape tensor to 2D with specified last dimension. + + This operation flattens a tensor to a 2-dimensional tensor where the last dimension + is fixed and the first dimension is automatically calculated from the total size. + + Args: + dim: The second dimension (last dimension) of the output 2D tensor. + """ + + def __init__( + self, + dim + ): + self.dim = dim + super().__init__() + self.local_reshape = Morph(self.reshape_fn, + self.reshape_infer_shape, + self.reshape_infer_dtype + ) + + def reshape_fn(self, x, shp): + """Reshape function.""" + return P.Reshape()(x, shp) + + def reshape_infer_shape(self, *args): + tensor_shape = args[0] + total = math.prod(tensor_shape) + t = [total // self.dim, self.dim] + return t + + def reshape_infer_dtype(self, *args): + return args[0] + + def construct(self, tensor, shp): + return self.local_reshape(tensor, shp) + + +def muon_split(tensor, part_a: int, part_b: int, num_blocks: int): + """ + Split a 2D tensor into two periodic parts along its first dimension. + The split pattern repeats every (part_a + part_b) elements. + + Args: + tensor: Input tensor of shape (M, N). + part_a: Number of elements in the first part of each block. + part_b: Number of elements in the second part of each block. + num_blocks: Total number of (part_a + part_b) blocks. + + Returns: + A tuple of two tensors (first_part, second_part), + where: + - first_part contains all part_a segments of each block. + - second_part contains all part_b segments of each block. 
+ """ + tensor = tensor.T + *prefix, _ = tensor.shape + block = part_a + part_b + t = BlockSplitReshape(block)(tensor, (*prefix, -1, block)) + + first_part = PrefixDimensionReshape(*prefix)(t[..., :part_a], (*prefix, -1)).T + second_part = PrefixDimensionReshape(*prefix)(t[..., part_a:], (*prefix, -1)).T + return first_part, second_part + + +def muon_merge(tensor_a, tensor_b, part_a: int, part_b: int, num_blocks: int): + """ + Merge two tensors back into the original periodic layout + that was split by muon_split(). + + Args: + tensor_a: Tensor containing the first part of each block. + tensor_b: Tensor containing the second part of each block. + part_a: Number of elements in the first part of each block. + part_b: Number of elements in the second part of each block. + num_blocks: Total number of (part_a + part_b) blocks. + + Returns: + A single tensor of the same shape as before muon_split(). + """ + tensor_a = tensor_a.T + tensor_b = tensor_b.T + *prefix, _ = tensor_a.shape + + a = BlockSplitReshape(part_a)(tensor_a, (*prefix, -1, part_a)) + b = BlockSplitReshape(part_b)(tensor_b, (*prefix, -1, part_b)) + t = P.Concat(axis=-1)([a, b]) + out = PrefixDimensionReshape(*prefix)(t, (*prefix, -1)).T + return out + + +def _eval_tuple(spec, name, tensor): + return spec(name, tensor) if callable(spec) else spec + + +def make_muon_fns(schema): + """ + Generate two generic functions: + - split_one(param_name, tensor) -> List[tensor] + - merge_one(param_name, parts_list) -> tensor + + Dimensions in schema should be either numbers or callback functions: + - periodic: rule["parts"] = (a, b, num_blocks) or lambda(name, tensor)->(a,b,blocks) + - reshape_* : rule["reshape"] = (x, y, z) or lambda(name, tensor)->(x,y,z) + """ + + def split_fn(param_name, tensor): + """ + Input a 2D tensor, split it according to schema rules, and return several segments (List[tensor]). + """ + + for rule in schema: + if not any(fnmatch(param_name, pat) for pat in rule["patterns"]): + continue + + kind = rule["kind"] + + if kind == "periodic": + part_a, part_b, num_blocks = _eval_tuple(rule["parts"], param_name, tensor) + first_part, second_part = muon_split(tensor, part_a, part_b, num_blocks) + return [first_part, second_part] + + if kind == "reshape_concat": + # e.g. experts.weight1: first reshape to [E, H, 2I], then split into two halves along the last dimension + _, hidden_size, total_intermediate = _eval_tuple(rule["reshape"], param_name, tensor) + half_intermediate = total_intermediate // 2 + t3 = TensorReshapeTo3D(hidden_size, total_intermediate)(tensor, (-1, hidden_size, total_intermediate)) + return [t3[..., :half_intermediate], t3[..., half_intermediate:]] + + if kind == "reshape_only": + # e.g. experts.weight2: just reshape to [E, I, H], no split + _, intermediate_size, hidden_size = _eval_tuple(rule["reshape"], param_name, tensor) + return [TensorReshapeTo3D(intermediate_size, hidden_size)(tensor, (-1, intermediate_size, hidden_size))] + + if kind == "alt_pair_periodic": + # Alternating rows 1,1 (blocks = M//2) + num_blocks = tensor.shape[0] // 2 + a, b = muon_split(tensor, 1, 1, num_blocks) + return [a, b] + + # Default: no processing, return as whole block + return [tensor] + + def merge_fn(param_name, parts_list): + """ + Merge the output of split_one (List[tensor]) back to 2D according to the same rules. 
+ """ + concat = P.Concat(axis=-1) + + for rule in schema: + if not any(fnmatch(param_name, pat) for pat in rule["patterns"]): + continue + + kind = rule["kind"] + + if kind == "periodic": + part_a, part_b, num_blocks = _eval_tuple(rule["parts"], param_name, parts_list[0]) + # Convention: periodic always has two segments + return muon_merge(parts_list[0], parts_list[1], part_a, part_b, num_blocks) + + if kind == "reshape_concat": + _, hidden_size, total_intermediate = _eval_tuple(rule["reshape"], param_name, parts_list[0]) + cat = concat([parts_list[0], parts_list[1]]) # [..., I] + [..., I] -> [..., 2I] + return TensorReshapeTo2D(total_intermediate)(cat, (-1, total_intermediate)) + + if kind == "reshape_only": + _, _, hidden_size = _eval_tuple(rule["reshape"], param_name, parts_list[0]) + # Only one segment, directly restore to 2D + return TensorReshapeTo2D(hidden_size)(parts_list[0], (-1, hidden_size)) + + if kind == "alt_pair_periodic": + num_blocks = parts_list[0].shape[0] # 1 row per block + return muon_merge(parts_list[0], parts_list[1], 1, 1, num_blocks) + + # Default: directly take the first segment + return parts_list[0] + + return split_fn, merge_fn diff --git a/mindformers/parallel_core/mf_model_config.py b/mindformers/parallel_core/mf_model_config.py index 8c017104490a17a06a7da224ba9b75a005fa3cfc..dd7ba2557fdb62e6e3d8625ee98303b543209fa4 100644 --- a/mindformers/parallel_core/mf_model_config.py +++ b/mindformers/parallel_core/mf_model_config.py @@ -62,7 +62,7 @@ def convert_str_to_mstype(type_str) -> dtype: if not isinstance(type_str, str): raise TypeError(f"The type of 'type_str' must 'string', but got '{type(type_str)}'.") - if type_str in ms_dtype_mapping.keys(): + if type_str in ms_dtype_mapping: return ms_dtype_mapping[type_str] raise ValueError(f"The value of 'type_str' must be in {list(ms_dtype_mapping.keys())}, " @@ -268,6 +268,9 @@ class MFModelConfig: mask_func_type: str = "attn_mask_fill" """Mask function type to use for the attention layer.""" + monitor_max_attention_logit: bool = False + """Whether to monitor the maximum attention logit value during training.""" + #################################################### # MoE Configuration Items For MindSpore Transformers #################################################### diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py index fd0145e955e64fe7009cf1ef4d3c99063756c321..c5df82ba4e545b47bf10e8bf2c955af425a44570 100644 --- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py +++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py @@ -26,6 +26,7 @@ from mindspore.ops.operations import Morph from mindspore import Tensor, dtype, nn from mindspore.context import ParallelMode from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation +from mindspore import ops from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss from mindformers.parallel_core.training_graph.transformer.multi_token_prediction import MultiTokenPredictionBlock @@ -52,6 +53,7 @@ from mindformers.parallel_core.utils.init_method import init_method_normal from mindformers.tools.logger import logger from mindformers.models.utils import get_current_rank_stage, get_model_parameters from mindformers.version_control import get_lazy_inline as lazy_inline +from mindformers.core.optim.muon_utils import make_muon_fns def func_infer_dtype(*args): @@ -723,3 +725,271 @@ class 
GPTModel(nn.Cell): else: params.update(get_model_parameters(self)) return params + + def make_model_muon_fns(self,): + """Read values from TransformersConfig and generate schema.""" + + num_moe_experts = self.config.num_moe_experts + hidden_size = self.config.hidden_size + moe_ffn_hidden_size = self.config.moe_ffn_hidden_size + qk_head_dim = self.config.qk_head_dim + qk_pos_emb_head_dim = self.config.qk_pos_emb_head_dim + num_attention_heads = self.config.num_attention_heads + kv_lora_rank = self.config.kv_lora_rank + value_head_dim = self.config.v_head_dim + + schema = [ + # experts.weight1: reshape → split into two [num_moe_experts, hidden_size, moe_ffn_hidden_size] + { + "patterns": ["*mlp.experts.weight1*"], + "kind": "reshape_concat", + "reshape": (num_moe_experts, hidden_size, 2 * moe_ffn_hidden_size), + }, + # experts.weight2: reshape → [num_moe_experts, moe_ffn_hidden_size, hidden_size] + { + "patterns": ["*mlp.experts.weight2*"], + "kind": "reshape_only", + "reshape": (num_moe_experts, moe_ffn_hidden_size, hidden_size), + }, + # q_proj / q_up_proj: periodic split across heads + { + "patterns": [ + "*self_attention.linear_q_proj.weight*", + "*self_attention.linear_q_up_proj.weight*", + ], + "kind": "periodic", + "parts": (qk_head_dim, qk_pos_emb_head_dim, num_attention_heads), + }, + # kv_down_proj: one block + { + "patterns": ["*self_attention.linear_kv_down_proj.weight*"], + "kind": "periodic", + "parts": (kv_lora_rank, qk_pos_emb_head_dim, 1), + }, + # kv_up_proj: periodic split across heads + { + "patterns": ["*self_attention.linear_kv_up_proj.weight*"], + "kind": "periodic", + "parts": (qk_head_dim, value_head_dim, num_attention_heads), + }, + # fc1 and shared_fc1: alternating 1,1 split along rows + { + "patterns": [ + "*mlp.shared_experts.linear_fc1.weight*", + "*mlp.linear_fc1.weight*", + ], + "kind": "alt_pair_periodic", + }, + ] + + return make_muon_fns(schema) + + def get_muon_filter(self): + """Return a filter function to determine if a parameter should use Muon optimization. + + Returns: + A function that takes a parameter and returns True if it should use Muon. + """ + def muon_filter(param): + return ( + (len(param.shape) == 2 or len(param.shape) == 3) + and "word_embeddings" not in param.name + and "output_layer" not in param.name + ) + return muon_filter + + def get_tp_dims(self, params): + """Return tensor parallel dimensions for each parameter. + + Args: + params: List of parameters from the optimizer. + + Returns: + Tuple of TP dimensions for each parameter. + """ + no_tp_list = [ + "linear_q_down_proj", + "linear_kv_down_proj", + "shared_experts", + "mlp.router", + "hnorm.weight", "enorm.weight", "eh_proj.weight", + ] + + tp_dim_1_list = [ + "self_attention.linear_proj.weight", + "mlp.linear_fc2.weight" + ] + + def name_filter(param_name, full_name_list): + for full_name in full_name_list: + if full_name in param_name: + return True + return False + + tp_dims = [] + for param in params: + if name_filter(param.name, tp_dim_1_list): + tp_dims.append(1) + elif name_filter(param.name, no_tp_list): + tp_dims.append(-1) + else: + tp_dims.append(0) + return tuple(tp_dims) + + def get_op_groups_info(self, params, op, tp_group, op_group): + """Return optimizer parallel group information for each parameter. + + Args: + params: List of parameters from the optimizer. + op: Optimizer parallel size. + tp_group: Tensor parallel group name. 
+ + Returns: + Tuple of (ops, op_groups) where: + - ops: tuple of op values for each parameter + - op_groups: tuple of group names for each parameter + """ + no_op_list = [ + "self_attention.linear_q_proj.weight", + "self_attention.linear_q_up_proj.weight", + "self_attention.linear_q_down_proj.weight", + "self_attention.linear_kv_up_proj.weight", + "self_attention.linear_kv_down_proj.weight", + "eh_proj" + ] + + use_tp_group_list = [ + "mlp.router.weight", + "mlp.shared_experts.linear_fc", + "self_attention.linear_q_down_proj.weight", + "eh_proj", + ] + + def name_filter(param_name, full_name_list): + for full_name in full_name_list: + if full_name in param_name: + return True + return False + + op_list = [] + op_groups = [] + + for param in params: + if name_filter(param.name, no_op_list): + op_list.append(1) + + op_groups.append("") + if param.parallel_optimizer: + param.parallel_optimizer = False + logger.warning( + f"Parameter {param.name}: parallel_optimizer was set to False due to the use of Muon optimizer." + ) + else: + op_list.append(op) + + if name_filter(param.name, use_tp_group_list): + op_groups.append(tp_group) + else: + op_groups.append(op_group) + + return tuple(op_list), tuple(op_groups) + + def get_param_layer_indices(self, params): + """Return layer indices for each parameter (used for QK-clip). + + Args: + params: List of parameters from the optimizer. + + Returns: + Tuple of layer indices for each parameter, where: + - layer_idx >= 0 stands for the layer_idx-th decoder layer + - layer_idx < 0 stands for the -(layer_idx+1)-th MTP layer + """ + param_layer = [] + for param in params: + name = param.name + try: + layer_idx = int(name.split(".")[2]) + except (ValueError, IndexError): + layer_idx = 0 + if name.startswith('mtp'): + layer_idx = -layer_idx - 1 + param_layer.append(layer_idx) + return tuple(param_layer) + + def apply_qk_clip_scaling(self, params, param_names, param_layer, logit_threshold, + muon_split_fn, muon_merge_fn): + """Apply QK-clip scaling to attention weight parameters. + + Args: + params: List of all parameters. + param_names: Tuple of parameter names. + param_layer: Tuple of layer indices for each parameter. + logit_threshold: Threshold for logit clipping. + muon_split_fn: Function to split parameters. + muon_merge_fn: Function to merge parameters. + + Returns: + List of (param_idx, scaled_weights) tuples to be updated. + """ + if not self.config.multi_latent_attention: + return [] + ones = ms.Tensor([1.0], dtype.float32) + qk_head_dim = self.config.qk_head_dim + qk_pos_emb_head_dim = self.config.qk_pos_emb_head_dim + + def get_scale_broadcast(scales, head_dim): + scale_broadcast = ops.tile(ops.expand_dims(scales, 1), (1, head_dim)).reshape(-1) + scale_broadcast = ops.expand_dims(scale_broadcast, 1) + return scale_broadcast + + # Build param name to index mapping + param_idx_in_opt = {name: idx for idx, name in enumerate(param_names)} + + updates = [] + for idx, param_name in enumerate(param_names): + if ( + "self_attention.linear_q_proj.weight" not in param_name + and "self_attention.linear_q_up_proj.weight" not in param_name + and "self_attention.linear_kv_up_proj.weight" not in param_name + ): + continue + + layer_idx = param_layer[idx] + param = params[idx] + + # Compute per-head scale factor + logit_threshold_f32 = ops.cast(logit_threshold, dtype=dtype.float32) + if layer_idx >= 0: + max_logits_name = (f"decoder.layers.{layer_idx}.self_attention." 
+ "core_attention.max_logits_val") + else: + max_logits_name = (f"mtp.layers.{-(layer_idx+1)}.transformer_layer." + "self_attention.core_attention.max_logits_val") + + if max_logits_name not in param_idx_in_opt: + continue + + logits_row = params[param_idx_in_opt[max_logits_name]].reshape(-1) + mask = ops.greater_equal(logits_row, logit_threshold_f32) + safe_den = ops.where(mask, logits_row, ones) + scales = ops.where(mask, logit_threshold_f32 / safe_den, ones) + + weights = None + if ( + "self_attention.linear_q_proj.weight" in param_name + or "self_attention.linear_q_up_proj.weight" in param_name + ): + l2q_nope_proj, l2q_pe_proj = muon_split_fn(param_name, param) + l2q_nope_proj *= get_scale_broadcast(ops.sqrt(scales), qk_head_dim) + l2q_pe_proj *= get_scale_broadcast(scales, qk_pos_emb_head_dim) + weights = muon_merge_fn(param_name, [l2q_nope_proj, l2q_pe_proj]) + elif "self_attention.linear_kv_up_proj.weight" in param_name: + lkv2kv_k_nope, lkv2kv_v = muon_split_fn(param_name, param) + lkv2kv_k_nope *= get_scale_broadcast(ops.sqrt(scales), qk_head_dim) + weights = muon_merge_fn(param_name, [lkv2kv_k_nope, lkv2kv_v]) + + if weights is not None: + updates.append((idx, weights)) + + return updates diff --git a/mindformers/parallel_core/training_graph/transformer/flash_attention.py b/mindformers/parallel_core/training_graph/transformer/flash_attention.py index 026e88db69c69c7a1104330fb98fd5dc492c9ebc..8abc070dfcc18ac4e35842377e5f27dd913d7e25 100644 --- a/mindformers/parallel_core/training_graph/transformer/flash_attention.py +++ b/mindformers/parallel_core/training_graph/transformer/flash_attention.py @@ -32,7 +32,6 @@ from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagati from mindformers.parallel_core.transformer_config import MLATransformerConfig from mindformers.parallel_core.training_graph.transformer.enums import AttnMaskType from mindformers.parallel_core.training_graph.device_matrix import layout -from mindformers.core.context.build_context import Context, get_context class FlashAttention(Cell): @@ -161,10 +160,8 @@ class FlashAttention(Cell): self.reshape = aclnn_ops.Reshape() self.fa_out_transpose = aclnn_ops.Transpose() - if Context.is_exists(): - self.monitor_max_attention_logit = get_context("monitor_max_attention_logit") - else: - self.monitor_max_attention_logit = False + self.monitor_max_attention_logit = self.config.monitor_max_attention_logit + if self.monitor_max_attention_logit: self.max_logits_val = Parameter( Tensor(np.zeros((1, self.head_num)), dtype=mstype.float32), @@ -306,7 +303,7 @@ class FlashAttention(Cell): attention_mask, prefix) if self.monitor_max_attention_logit: - max_logits = ops.ReduceMax()(softmax_val, (2,3)) + max_logits = ops.ReduceMax()(softmax_val, (2, 3)) max_logits = ops.ReduceMax(keep_dims=True)(max_logits, (0)) output = F.depend(output, self.assign_add(self.max_logits_val, max_logits)) diff --git a/mindformers/parallel_core/transformer_config_utils.py b/mindformers/parallel_core/transformer_config_utils.py index 65504533750da18e52d3ca89224021edf6267362..0d762645b25a5bc9ce40e0ee5c0111cf844bd7cd 100644 --- a/mindformers/parallel_core/transformer_config_utils.py +++ b/mindformers/parallel_core/transformer_config_utils.py @@ -304,6 +304,7 @@ COMMON_CONFIG_MAPPING = { "hidden_act": "hidden_act", "mask_func_type": "mask_func_type", "param_init_std_rules": "param_init_std_rules", + "monitor_max_attention_logit": "monitor_max_attention_logit", ("extend_method", "position_embedding_type"): "position_embedding_type", 
("init_method_std", "initializer_range"): "init_method_std", diff --git a/mindformers/parallel_core/utils/model_mixin.py b/mindformers/parallel_core/utils/model_mixin.py index 4524aaac77272863fd0aeb3060651442fba1427a..87bb49c2e4a790ad29a635681ee3064640f44129 100644 --- a/mindformers/parallel_core/utils/model_mixin.py +++ b/mindformers/parallel_core/utils/model_mixin.py @@ -451,12 +451,55 @@ class TrainModelMixin: raise ValueError(f"the length of cur_layer_linear_fc2_weights_dict is " f"{len(cur_layer_linear_fc2_weights_dict)}, can't stack them.") - def get_model_parameters(self): - """Get current rank trainable parameters in model .""" + def check_and_get_model(self): + """Check and get GPT model instance.""" if not hasattr(self, 'model'): raise RuntimeError("Mcore model definition should use the fixed paradigm: " "self.model = GPTModel(*args, **kwargs) definition. " "Currently, this attribute cannot be correctly recognized. " "Please modify the GPTModel definition method.") - model = getattr(self, 'model') + return getattr(self, 'model') + + def get_model_parameters(self): + """Get current rank trainable parameters in model .""" + model = self.check_and_get_model() return model.get_model_parameters() + + def make_model_muon_fns(self): + """Make model muon functions.""" + model = self.check_and_get_model() + return model.make_model_muon_fns() + + def get_muon_filter(self): + """Get muon filter.""" + model = self.check_and_get_model() + return model.get_muon_filter() + + def get_tp_dims(self, parameters): + """Get tensor parallel dimensions for parameters.""" + model = self.check_and_get_model() + return model.get_tp_dims(parameters) + + def get_op_groups_info(self, parameters, op_size, tp_group, op_group): + """Get operation groups information for parameters.""" + model = self.check_and_get_model() + return model.get_op_groups_info(parameters, op_size, tp_group, op_group) + + def get_parallel_config_for_muon(self): + """Get parallel configuration for Muon optimizer.""" + model = self.check_and_get_model() + return model.get_parallel_config_for_muon() + + def get_param_layer_indices(self, parameters): + """Get layer indices for parameters.""" + model = self.check_and_get_model() + return model.get_param_layer_indices(parameters) + + def apply_qk_clip_scaling(self, parameters, param_names, param_layers, + logit_threshold, split_fn, merge_fn): + """Apply QK clip scaling to parameters.""" + model = self.check_and_get_model() + return model.apply_qk_clip_scaling( + parameters, param_names, param_layers, + logit_threshold, split_fn, merge_fn + ) diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 9b0924360d77a228ca8d3d9eb1a3bb58b38c766a..68d1a43c99a3bab8b79d9aeb536fc0ad7c575c17 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -471,6 +471,7 @@ class BaseTrainer: if self.config.get("generation_config", None): self.config.model.generation_config = self.config.generation_config network = build_network(self.config.model, default_args=default_args) + self.real_model = network if hasattr(network, "check_pipeline_stage") and callable(network.check_pipeline_stage): network.check_pipeline_stage() return network @@ -643,7 +644,7 @@ class BaseTrainer: default_args = {"params": group_params, "learning_rate": self.lr_scheduler} if self.config.optimizer.type == "Muon": default_args["micro_batch_num"] = self.config.parallel_config.micro_batch_num - default_args["network"] = network + default_args["model"] = None if not 
hasattr(self, 'real_model') else self.real_model self.optimizer = build_optim( self.config.optimizer, default_args=default_args) @@ -660,7 +661,7 @@ class BaseTrainer: default_args = {"params": group_params} if self.config.optimizer.type == "Muon": default_args["micro_batch_num"] = self.config.parallel_config.micro_batch_num - default_args["network"] = network + default_args["model"] = None if not hasattr(self, 'real_model') else self.real_model # Build optimizer with fixed learning rate self.optimizer = build_optim( self.config.optimizer, default_args=default_args) diff --git a/mindformers/trainer/trainer.py b/mindformers/trainer/trainer.py index 645df8949e66e324749eccf6716f930ce0e8473c..1b218ed711ad423b1e996d0b28e0433765a662d2 100644 --- a/mindformers/trainer/trainer.py +++ b/mindformers/trainer/trainer.py @@ -269,7 +269,7 @@ class Trainer: self.config = self._config_init(args, task_config) if self.config.optimizer.type == "Muon": - set_context(monitor_max_attention_logit=True) + self.config.model.model_config.monitor_max_attention_logit = True self._reassign_monitor_config() # build parallel config build_parallel_config(self.config) @@ -343,10 +343,10 @@ class Trainer: dump_local_norm=bool(monitor_config.get('local_norm_format')), dump_device_local_norm=bool(monitor_config.get('device_local_norm_format')) ) + if monitor_config.max_attention_logit_format: + self.config.model.model_config.monitor_max_attention_logit = True if monitor_config.local_loss_format: set_context(monitor_local_loss=True) - if monitor_config.max_attention_logit_format: - set_context(monitor_max_attention_logit=True) if monitor_config.device_local_loss_format: set_context(monitor_device_local_loss=True) for callback in self.config.callbacks:
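Note on the Newton-Schulz step introduced above: `zeropower_via_newtonschulz5_2d` approximately orthogonalizes the (bfloat16-cast) momentum matrix before it is used as the update direction. The plain-NumPy sketch below mirrors the same five iterations and quintic coefficients so the behaviour can be checked outside MindSpore; the NumPy port, the helper name `newton_schulz5`, and the random test matrix are illustrative assumptions, not part of the patch.

# Minimal NumPy sketch of the 5-step Newton-Schulz orthogonalization in
# zeropower_via_newtonschulz5_2d (illustration only; the patch itself runs on MindSpore ops).
import numpy as np

def newton_schulz5(x):
    a, b, c = 3.4445, -4.7750, 2.0315        # same quintic coefficients as in muon.py
    transposed = x.shape[0] > x.shape[1]
    if transposed:                           # work in the "wide" orientation, as the 2D variant does
        x = x.T
    x = x / (np.linalg.norm(x) + 1e-7)       # Frobenius-norm scaling keeps the spectral norm <= 1
    for _ in range(5):
        a_mat = x @ x.T
        b_mat = b * a_mat + c * a_mat @ a_mat
        x = a * x + b_mat @ x
    return x.T if transposed else x

g = np.random.randn(64, 256).astype(np.float32)     # stand-in for a momentum matrix
sv = np.linalg.svd(newton_schulz5(g), compute_uv=False)
print(sv.min(), sv.max())                           # singular values end up close to 1

In `_apply_muon_update` this orthogonalized matrix then replaces the raw momentum, with the learning rate rescaled by `sqrt(max(dim_a, dim_b)) * matched_adamw_rms` so the update RMS stays comparable to AdamW.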
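The `periodic` schema rule handled by `muon_split` / `muon_merge` in `muon_utils.py` separates row blocks that are interleaved per attention head (for example the `qk_head_dim` "nope" rows and `qk_pos_emb_head_dim` rope rows of `linear_q_up_proj.weight`), so each part can be orthogonalized as its own 2D matrix and then re-interleaved. The NumPy sketch below shows the row layout this assumes; the helper names and toy sizes are illustrative stand-ins, not taken from the patch.

# Toy illustration of the periodic row split/merge performed by muon_split / muon_merge.
import numpy as np

def periodic_split(w, part_a, part_b):
    # Rows of w repeat as [part_a rows | part_b rows] blocks (one block per head).
    blocks = w.reshape(-1, part_a + part_b, w.shape[1])
    return (blocks[:, :part_a, :].reshape(-1, w.shape[1]),
            blocks[:, part_a:, :].reshape(-1, w.shape[1]))

def periodic_merge(a, b, part_a, part_b):
    cols = a.shape[1]
    blocks = np.concatenate([a.reshape(-1, part_a, cols), b.reshape(-1, part_b, cols)], axis=1)
    return blocks.reshape(-1, cols)

heads, d_nope, d_rope, hidden = 4, 3, 2, 8   # toy stand-ins for num_attention_heads, qk_head_dim, qk_pos_emb_head_dim, hidden_size
w = np.arange(heads * (d_nope + d_rope) * hidden, dtype=np.float32).reshape(-1, hidden)
nope, rope = periodic_split(w, d_nope, d_rope)
assert nope.shape == (heads * d_nope, hidden) and rope.shape == (heads * d_rope, hidden)
assert np.array_equal(periodic_merge(nope, rope, d_nope, d_rope), w)   # round trip restores the original layout

The same per-head row grouping is what `apply_qk_clip_scaling` relies on when it splits the query and key-value up-projection weights with `muon_split_fn`, rescales the parts with per-head factors derived from `max_logits_val`, and merges them back with `muon_merge_fn`.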