From a73af11b9b7bad0dce753ebfbe57d050721c41f2 Mon Sep 17 00:00:00 2001 From: niujunhao Date: Wed, 29 Oct 2025 09:46:16 +0800 Subject: [PATCH] [feature] add grouped lr. --- .jenkins/test/config/dependent_packages.yaml | 4 +- mindformers/core/callback/callback.py | 115 +++++----- mindformers/trainer/base_trainer.py | 200 +++++++++++------- .../trainer/optimizer_grouped_parameters.py | 164 +++++++++----- mindformers/wrapper/wrapper.py | 132 +++++++----- .../test_optimizer_grouped_parameters.py | 149 +++++++++++++ 6 files changed, 535 insertions(+), 229 deletions(-) create mode 100644 tests/st/test_ut/test_optimizer_grouped_parameters.py diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml index 458da540f..b6b71a780 100644 --- a/.jenkins/test/config/dependent_packages.yaml +++ b/.jenkins/test/config/dependent_packages.yaml @@ -1,2 +1,4 @@ mindspore: - 'https://repo.mindspore.cn/mindspore/mindspore/version/202509/20250920/master_20250920160018_8cfebd124882a46f9c59870f39eabc8c7ba57f8c_newest/' \ No newline at end of file + 'https://repo.mindspore.cn/mindspore/mindspore/version/202510/20251025/r2.7.2_20251025170006_0bd92f1f323c4b5ec120d5f1de47862c4088ca6f/' +ms_custom_ops: + 'https://repo.mindspore.cn/mindspore/ms_custom_ops/version/202510/20251025/master_20251025031507_443753ae14242e84f2e62e3727d7de72ba7887c3_newest/' \ No newline at end of file diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py index 2b1bf9fd8..d71278552 100644 --- a/mindformers/core/callback/callback.py +++ b/mindformers/core/callback/callback.py @@ -83,6 +83,7 @@ from mindformers.parallel_core.training_graph.loss_func import ( ) from mindformers.checkpoint.checkpoint import AsyncSaveManager, CommonInfo, save_checkpoint +# pylint: disable=import-outside-toplevel __all__ = ['MFLossMonitor', 'CheckpointMonitor', 'SummaryMonitor', 'ProfileMonitor', 'EvalCallBack'] _cur_dir = os.getcwd() @@ -97,7 +98,7 @@ class AllReduceNet(Cell): """ def __init__(self, group_name): - super(AllReduceNet, self).__init__() + super().__init__() self.allreduce_sum = P.AllReduce(op=P.ReduceOp.SUM, group=group_name) self.add_flags(skip_auto_parallel_compile=True) @@ -122,6 +123,22 @@ def _get_separate_loss(): return lm_loss, aux_loss, mtp_loss +def _log_grouped_lr_info(): + """Log the current learning rate values for the default and grouped parameter sets.""" + from mindformers.trainer.optimizer_grouped_parameters import GROUPED_PARAMS + if not GROUPED_PARAMS: # Skip logging if no grouped parameters are registered + return + + # Retrieve the current default learning rate from parameter registry + default_lr = parameter_register.get("current_default_lr").asnumpy() + logger.info(f"default_lr: {default_lr:.6e}, equal to `lr:` above.") + + # Retrieve the current grouped learning rate from parameter registry + grouped_lr = parameter_register.get("current_grouped_lr").asnumpy() + for group_id, params in enumerate(GROUPED_PARAMS): + logger.info(f"group_{group_id}_lr: {grouped_lr[group_id]:.6e}, group_{group_id}_params: {params}") + + def _get_loss_output(output): """Get output of task for MFLossMonitor.""" overflow = False @@ -224,7 +241,7 @@ class MFLossMonitor(Callback): calculate_per_token_loss: bool = False, print_separate_loss: bool = False, **kwargs): - super(MFLossMonitor, self).__init__() + super().__init__() self.per_print_times = per_print_times self.learning_rate = deepcopy(learning_rate) self.last_print_time = 0 @@ -423,8 +440,8 @@ class 
MFLossMonitor(Callback): return False if not is_legacy_model() and self.is_moe_model: - logger.warning(f"Model Flops computation is not support when using GroupMatMul MoELayer, " - f"due to dynamic shape") + logger.warning("Model Flops computation is not supported when using GroupMatMul MoELayer, " + "due to dynamic shape") return False self.current_phase = network.current_phase @@ -520,7 +537,7 @@ class MFLossMonitor(Callback): global_step = cur_step_num + (cur_epoch_num - 1) * steps_per_epoch if self.mf_calculated: throughput_per_npu = self.full_model_flops / per_step_seconds / 1e9 - throughput_info = ', train_throughput_per_npu: %.3fT' % (throughput_per_npu) + throughput_info = f', train_throughput_per_npu: {throughput_per_npu:.3f}T' if self.tensor_writer is not None: self.tensor_writer.add_scalar('model-flops-throughput-per-npu', float(throughput_per_npu), @@ -529,15 +546,15 @@ class MFLossMonitor(Callback): throughput_info = '' if cb_params.dataset_sink_mode: - loss_info = "loss: %5.6f, " % loss + loss_info = f"loss: {loss:5.6f}, " else: - loss_info = "loss:[%5.6f/%5.6f], " % (loss, np.mean(self.loss_list)) + loss_info = f"loss:[{loss:5.6f}/{np.mean(self.loss_list):5.6f}], " if self.print_separate_loss: - separate_loss = "lm_loss: %5.6f, " % main_loss + separate_loss = f"lm_loss: {main_loss[0]:5.6f}, " if self.is_moe_model and np.all(extra_loss > 0): - separate_loss += "load_balancing_loss: %5.6f, " % extra_loss + separate_loss += f"load_balancing_loss: {extra_loss[0]:5.6f}, " if self.is_mtp_model: - separate_loss += "mtp_loss: %5.6f, " % mtp_loss + separate_loss += f"mtp_loss: {mtp_loss[0]:5.6f}, " else: separate_loss = "" if current_lr is not None: @@ -556,9 +573,12 @@ class MFLossMonitor(Callback): int(per_step_seconds), overflow, scaling_sens, global_norm, throughput_info) # print progress bar - show_str = ('|%%-%ds|' % 50) % (int(50 * percent / 100) * "█") + bar = int(50 * percent / 100) * "█" + show_str = f"|{bar:<50}|" logger.info(" %4.1f%% %s %.5f samples/s/p %s }", percent, show_str, throughput, timedelta(seconds=int(time_remain))) + # log grouped lr info if enabled + _log_grouped_lr_info() # write tensorboard if self.tensor_writer is not None: @@ -674,7 +694,7 @@ class TrainingStateMonitor(Callback): use_skip_data_by_global_norm: bool = False, embedding_size: int = 4096, use_local_norm: bool = False): - super(TrainingStateMonitor, self).__init__() + super().__init__() if not (isinstance(step_interval, int) and step_interval > 0): logger.warning(f"The value of 'monitor_config.step_interval' should be positive integer, " f"but get {step_interval}. 
Use default value: 1.") @@ -831,7 +851,7 @@ class TrainingStateMonitor(Callback): parent_dirs = os.path.dirname(self.global_norm_record_path) if not os.path.exists(parent_dirs): os.makedirs(parent_dirs) - with open(self.global_norm_record_path, 'w') as file: + with open(self.global_norm_record_path, 'w', encoding="utf-8") as file: json.dump(self.abnormal_global_norms, file) set_safe_mode_for_file_or_dir(self.global_norm_record_path) logger.info(f"Current global norm {global_norm} is greater equal than " @@ -897,7 +917,7 @@ class TrainingStateMonitor(Callback): if self.print_struct is None: self.print_struct = False - if not (isinstance(self.target, list) and self.target and all([isinstance(i, str) for i in self.target])): + if not (isinstance(self.target, list) and self.target and all(isinstance(i, str) for i in self.target)): raise TypeError("The value of 'target' should be a list of str.") if not isinstance(self.invert, bool): raise TypeError("The value of 'invert' should be bool.") @@ -913,7 +933,7 @@ class TrainingStateMonitor(Callback): if self.global_norm_record_path and os.path.exists(self.global_norm_record_path): # the data format might be like {"300": [3.3], "600": [4.1, 4.2],} # because json cannot use number as key, we convert it to string - with open(self.global_norm_record_path, 'r') as file: + with open(self.global_norm_record_path, 'r', encoding="utf-8") as file: self.abnormal_global_norms = json.load(file) def _check_attr_formats(self, attr): @@ -1271,7 +1291,7 @@ class CheckpointMonitor(ModelCheckpoint): self.save_checkpoint_path = save_checkpoint_path self.need_remove_redundancy = remove_redundancy - prefix = prefix + "_rank_{}".format(self.rank_id) + prefix = prefix + f"_rank_{self.rank_id}" self.global_batch_size = global_batch_size # this list records parameters which will be ignored when saving ckpt. 
@@ -1316,8 +1336,7 @@ class CheckpointMonitor(ModelCheckpoint): format=checkpoint_format, exception_save=exception_save, remove_redundancy=remove_redundancy) - super(CheckpointMonitor, self).__init__(prefix, ckpt_directory if self.use_legacy_format else None, - config=config_ck) + super().__init__(prefix, ckpt_directory if self.use_legacy_format else None, config=config_ck) self.meta_json = os.path.join(self._directory, "meta.json") if self._config.async_save: self.last_epoch_num = None @@ -1378,8 +1397,8 @@ class CheckpointMonitor(ModelCheckpoint): keys = list(self.save_info_list.keys()) for record_step in keys: self.print_savetime(record_step, cb_params.batch_num) - if not any([self.save_info_list[record_step][key]['ckpt_file_path'] - for key in ['ckpt', 'network', 'trainable_params']]): + if not any(self.save_info_list[record_step][key]['ckpt_file_path'] + for key in ['ckpt', 'network', 'trainable_params']): self.save_info_list.pop(record_step) if self._config.async_save and not ms.async_ckpt_thread_status() and \ @@ -1469,11 +1488,11 @@ class CheckpointMonitor(ModelCheckpoint): } all_step_health_data = [] if os.path.exists(dump_health_json_path): - with open(dump_health_json_path, 'r') as file: + with open(dump_health_json_path, 'r', encoding="utf-8") as file: data = json.load(file) all_step_health_data = list(data) all_step_health_data.append(health_step_data) - with open(dump_health_json_path, 'w') as file: + with open(dump_health_json_path, 'w', encoding="utf-8") as file: json.dump(all_step_health_data, file, indent=4) set_safe_mode_for_file_or_dir(dump_health_json_path) @@ -1859,7 +1878,7 @@ class ProfileMonitor(Callback): start_profile=True, profile_rank_ids=None, profile_pipeline=False, profile_communication=False, profile_memory=False, config=None, profiler_level=0, with_stack=False, data_simplification=True, mstx=False, **kwargs): - super(ProfileMonitor, self).__init__() + super().__init__() self.mstx_range_id = None self.mstx_enabled = not _check_mspti_is_on() self.stop_step = stop_step @@ -1891,8 +1910,8 @@ class ProfileMonitor(Callback): if not output_path: output_path = get_output_subpath('profile', rank_id) else: - output_path = os.path.join(output_path, 'profile', 'rank_{}'.format(rank_id)) - logger.info("Profile save path: %s", output_path) + output_path = os.path.join(output_path, 'profile', f'rank_{rank_id}') + logger.info(f"Profile save path: {output_path}") if ms.get_context("device_target") == "GPU" and profile_memory: logger.warning("The parameter profile_memory is not supported on GPU currently, " @@ -2168,7 +2187,7 @@ class ColdHotExpertMonitor(Callback): self.local_expert_num = self.expert_num // self.ep start_index = (self.rank_id // self.mp) * self.local_expert_num end_index = start_index + self.local_expert_num - self.local_expert_index = [i for i in range(start_index, end_index)] + self.local_expert_index = list(range(start_index, end_index)) self.rank_size = int(os.getenv("RANK_SIZE")) def on_train_step_end(self, run_context): @@ -2541,14 +2560,14 @@ class MoEDropRateCallback(Callback): if fi.sum() > 0: delta = fi - self.capacity_factor_over_expert_num droprate = ms.ops.sum(delta * (delta > 0)) - logger.info("layer: %d, drop_rate: %.5f" % (i, droprate)) + logger.info(f"layer: {i}, drop_rate: {droprate:.5f}") else: if hasattr(network.model.layers[i].feed_forward, "router"): fi = network.model.layers[i].feed_forward.router.router.fi_parameter.value() if fi.sum() > 0: delta = fi - self.capacity_factor_over_expert_num droprate = ms.ops.sum(delta * (delta > 0)) 
- logger.info("layer: %d, drop_rate: %.5f" % (i, droprate)) + logger.info(f"layer: {i}, drop_rate: {droprate:.5f}") def on_train_step_end(self, run_context): """get expert drop rate at the end of step.""" @@ -2602,7 +2621,7 @@ class StressTestModelMonitor(Callback): check_stresslog_interval_time=60): logger.warning('StressTestModelMonitor serves as an experimental interface and its functionality is ' 'not yet stable.') - super(StressTestModelMonitor, self).__init__() + super().__init__() self.interval_steps = interval_steps self.last_checked_step = 0 @@ -2676,20 +2695,19 @@ class StressTestModelMonitor(Callback): log_file_path = os.path.join(saved_dir, "worker_0.log") # Start the subprocess command = shlex.split(command) - result_1 = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - - # Monitor the log file - while result_1.poll() is None: # While the subprocess is running - time.sleep(self.check_stresslog_interval_time) - log_msg = self.readlog(log_file_path) - logger.info(f"Checking stress test log every {self.check_stresslog_interval_time} seconds") - logger.info(f"Current state of stress_test: {log_msg}") - - # Once the subprocess has finished, check the result - if result_1.returncode != 0: - logger.warning(f"An error occurred while running the stress test model on rank {rank_id}: \ - {result_1.stderr.read().decode('utf-8')}") - logger.warning(f"Check the sub task workers log for rank {rank_id} for more details.") + with subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as result_1: + # Monitor the log file + while result_1.poll() is None: # While the subprocess is running + time.sleep(self.check_stresslog_interval_time) + log_msg = self.readlog(log_file_path) + logger.info(f"Checking stress test log every {self.check_stresslog_interval_time} seconds") + logger.info(f"Current state of stress_test: {log_msg}") + + # Once the subprocess has finished, check the result + if result_1.returncode != 0: + logger.warning(f"An error occurred while running the stress test model on rank {rank_id}: \ + {result_1.stderr.read().decode('utf-8')}") + logger.warning(f"Check the sub task workers log for rank {rank_id} for more details.") barrier() logger.info("Stress tests ended, now starting to collect and compare results") @@ -2700,6 +2718,7 @@ class StressTestModelMonitor(Callback): "so only the last step result is compared.") else: interval_results = None + subtask_global_step_num = None logger.info(f"Test results are compared every {self.compare_interval_steps} steps") for i in range(self.worker_num): if get_rank() % self.worker_num == i: @@ -2755,7 +2774,7 @@ class StressTestModelMonitor(Callback): """Extract the last step's results from the log file.""" loss_value = None global_norm_value = None - with open(log_file, 'r') as file: + with open(log_file, 'r', encoding="utf-8") as file: lines = file.readlines() for line in reversed(lines): if "INFO - {" in line: @@ -2771,7 +2790,7 @@ class StressTestModelMonitor(Callback): last_recorded_step = 0 results = [] steps_per_epoch = None - with open(log_file, 'r') as file: + with open(log_file, 'r', encoding="utf-8") as file: lines = file.readlines() for line in lines: if "INFO - {" in line: @@ -2855,7 +2874,7 @@ class StressTestModelMonitor(Callback): """ Search for the latest line indicating training has started, based on key identifiers. 
""" - with open(file_path, 'r', errors='ignore') as f: + with open(file_path, 'r', errors='ignore', encoding="utf-8") as f: lines = f.readlines() # Define the keywords indicating training has started @@ -2894,7 +2913,7 @@ class SDCMonitor(Callback): strike_num: int = 3, checksum_time: int = 5, checksum_cooldown_time: int = 180): - super(SDCMonitor, self).__init__() + super().__init__() logger.warning('SDCMonitor serves as an experimental interface and its functionality is not yet stable.') npu_asd_enable = int(os.getenv('NPU_ASD_ENABLE', '0')) @@ -2946,7 +2965,7 @@ class SDCMonitor(Callback): file_path = os.path.join(self.device_log_path, file) if not os.path.exists(file_path): continue - with open(file_path, 'r') as f: + with open(file_path, 'r', encoding="utf-8") as f: logs = f.read() error_logs = self.silent_check_error_pattern.findall(logs) for log in error_logs: diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 0bed5b38b..9365f13c1 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -81,6 +81,7 @@ from .utils import set_seed, check_train_data_loader_type, \ check_eval_data_loader_type, check_optimizer_and_lr_type, check_wrapper_config from ..version_control import check_tft_valid, check_tre_valid, check_tsp_valid, check_is_reboot_node +# pylint: disable=import-outside-toplevel SUPPORT_TASKS = MindFormerBook().get_trainer_support_task_list() SUPPORT_MODEL_NAMES = MindFormerBook().get_model_name_support_list() SUPPORT_PIPELINES = MindFormerBook().get_pipeline_support_task_list() @@ -105,12 +106,12 @@ class BaseTrainer: def __init__(self, task: str = None, model_name: str = None): host_name_output = subprocess.run(['hostname'], shell=False, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, encoding='utf-8') + stderr=subprocess.PIPE, encoding='utf-8', check=True) host_ip_output = subprocess.run(['hostname', '-I'], shell=False, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, encoding='utf-8') + stderr=subprocess.PIPE, encoding='utf-8', check=True) host_name = host_name_output.stdout.strip() host_ip = host_ip_output.stdout.strip().split(' ')[0] - logger.info("host_name: %s, host_ip: %s" % (host_name, host_ip)) + logger.info(f"host_name: {host_name}, host_ip: {host_ip}") if model_name is None: model_name = "model name unspecified." @@ -140,6 +141,9 @@ class BaseTrainer: self.network_delay_inited = False self.optimizer_delay_inited = False + self.lr_scheduler = None + self.grouped_lr_scheduler = None + if task not in SUPPORT_TASKS.keys(): logger.warning("Input task name is not in the supported list or unspecified.") @@ -579,93 +583,136 @@ class BaseTrainer: self.config.processor.image_processor, default_args=default_args) return self.image_processor - def create_optimizer_scheduler(self, network, model_params: set, layer_scale=False): + def _create_grouped_lr_scheduler(self, learning_scale, scale_factor): + """ + Create learning rate (LR) schedulers for both default and parameter-grouped configurations. 
+ """ + default_config = self.config.grouped_lr_schedule.default + default_lr_scheduler = self.create_lr_scheduler(default_config, learning_scale, scale_factor) + + grouped_lr_scheduler = [] + grouped_config = self.config.grouped_lr_schedule.grouped + + # Iterate over each grouped LR configuration + for lr_config in grouped_config: + params = lr_config.pop('params', None) + if not params: + raise ValueError("If using grouped_lr_schedule, 'params' must be set correctly in `grouped`.") + + lr_config = MindFormerConfig(**lr_config) + lr_scheduler = self.create_lr_scheduler(lr_config, learning_scale, scale_factor) + grouped_lr_scheduler.append({ + 'params': params, + 'lr_scheduler': lr_scheduler, + 'lr_config': lr_config + }) + + return default_lr_scheduler, grouped_lr_scheduler + + def create_optimizer_scheduler(self, network, model_params: set): """Create the optimizer for training.""" - logger.info(".........Build Optimizer From Config..........") - # learning rate scale for multi-nodes training + logger.info("..........Build Optimizer From Config..........") + # Get learning rate scaling settings learning_scale = self.config.lr_scale scale_factor = self.config.lr_scale_factor - # build learning rate schedule - lr_schedule = self.create_lr_scheduler(learning_scale, scale_factor) - - weight_decay = self.config.optimizer.weight_decay if self.config.optimizer.weight_decay else 0. - layer_decay = self.config.layer_decay if self.config.layer_decay else 1.0 - group_params = get_optimizer_grouped_parameters(network, - weight_decay, - lr_schedule, - layer_scale=layer_scale, - layer_decay=layer_decay, - optimizer_type=self.config.optimizer.type, - model_params=model_params) - if lr_schedule is not None: + # Build the learning rate scheduler + if self.config.grouped_lr_schedule is not None: + # Create grouped LR scheduler if configured (e.g., different LRs for different parameter groups) + self.lr_scheduler, self.grouped_lr_scheduler = self._create_grouped_lr_scheduler( + learning_scale, scale_factor) + else: + # Create standard LR scheduler + self.lr_scheduler = self.create_lr_scheduler( + self.config.lr_schedule, learning_scale, scale_factor) + + # Build optimizer parameter groups (apply weight decay, LR groups, etc.) 
group_params = get_optimizer_grouped_parameters( + model=network, + weight_decay=self.config.optimizer.weight_decay, + dynamic_lr_schedule=self.lr_scheduler, + grouped_lr_schedule=self.grouped_lr_scheduler, + model_params=model_params, + optimizer_type=self.config.optimizer.type, + layer_scale=self.config.layer_scale, + layer_decay=self.config.layer_decay + ) + + # Build optimizer with dynamic lr_scheduler if available + if self.lr_scheduler is not None: self.optimizer = build_optim( self.config.optimizer, - default_args={"params": group_params, - "learning_rate": lr_schedule}) + default_args={"params": group_params, "learning_rate": self.lr_scheduler}) else: + # Otherwise, fall back to the static learning rate from the optimizer config if self.config.optimizer.learning_rate is None: - raise ValueError("learning_rate must be input") - self.config.optimizer.learning_rate = self.learning_rate_scale( - self.config.optimizer.learning_rate, scale_factor) \ - if learning_scale and scale_factor is not None else self.config.optimizer.learning_rate + raise ValueError("learning_rate must be set in the optimizer config when no lr_schedule is provided.") + + learning_rate = self.config.optimizer.learning_rate + if learning_scale and scale_factor is not None: + # reset learning_rate in optimizer with scale_factor if learning_scale is True + self.config.optimizer.learning_rate = self.learning_rate_scale(learning_rate, scale_factor) + + # Build optimizer with fixed learning rate self.optimizer = build_optim( - self.config.optimizer, - default_args={"params": group_params}) + self.config.optimizer, default_args={"params": group_params}) + return self.optimizer - def create_optimizer_scheduler_without_param_init(self, network, model_params: set, layer_scale=False): + def create_optimizer_scheduler_without_param_init(self, network, model_params: set): """Create the optimizer for training without initialize parameters.""" with no_init_parameters(): - optimizer = self.create_optimizer_scheduler(network=network, - model_params=model_params, - layer_scale=layer_scale) + optimizer = self.create_optimizer_scheduler(network=network, model_params=model_params) logger.info("Parameters are not initialized during optimizer initialization.") self.optimizer_delay_inited = True return optimizer - def create_lr_scheduler(self, learning_scale: bool = False, scale_factor: int = 256): + def create_lr_scheduler(self, lr_config: dict, learning_scale: bool = False, scale_factor: int = 256): """Create the learning rate scheduler.""" - logger.info(".........Build LR Schedule From Config..........") + logger.info("..........Build LR Schedule From Config..........") train_data_size = self.get_train_data_size() + warmup_lr_init = None - if self.config.lr_schedule: - warmup_epochs = self.config.lr_schedule.pop("warmup_epochs", None) - warmup_lr_init = self.config.lr_schedule.get("warmup_lr_init", None) - - if warmup_epochs is not None: - if not isinstance(warmup_epochs, int): - raise ValueError(f"The type of warmup_epochs must be int, but got type {type(warmup_epochs)}.") - if warmup_epochs < 0: - raise ValueError(f"The value of warmup_epochs must be non-negative integer, " - f"but got {warmup_epochs}.") - - if not self.config.runner_config.sink_mode: - total_steps = int(self.config.runner_config.epochs * train_data_size) - else: + if lr_config: + warmup_epochs = lr_config.warmup_epochs + warmup_ratio = lr_config.warmup_ratio + warmup_lr_init = lr_config.warmup_lr_init + + # Calculate total training steps depending on sink_mode + if self.config.runner_config.sink_mode: total_steps = 
int(self.config.runner_config.epochs * self.config.runner_config.sink_size) + else: + total_steps = int(self.config.runner_config.epochs * train_data_size) - if warmup_epochs is not None and self.config.lr_schedule.warmup_ratio is not None: - logger.warning("warmup_epochs and warmup_ratio are set simultaneously," - "warmup_ratio takes precedence.") - warmup_epochs = None + # Set total_steps in `lr_schedule` if not explicitly defined + if lr_config.total_steps is None or lr_config.total_steps == -1: + lr_config.total_steps = total_steps + else: + lr_config.total_steps = int(lr_config.total_steps) + + # Compute warmup steps if warmup is enabled + if warmup_epochs: + if not isinstance(warmup_epochs, int) or warmup_epochs < 0: + raise ValueError("`warmup_epochs` must be a non-negative integer.") + if not warmup_ratio: + raise ValueError("`warmup_ratio` must be specified when warmup is enabled.") + + warmup_steps = int(warmup_epochs * train_data_size) + lr_config.warmup_steps = warmup_steps + logger.info( + f"Get warmup_steps({warmup_steps}) = " + f"int(warmup_epochs({warmup_epochs}) * train_data_size({train_data_size}))" + ) - if warmup_epochs is not None: - logger.warning("warmup_epochs was set in lr_schedule," - "it will multiply the data size to represent the warmup steps") - self.config.lr_schedule.warmup_steps = int(warmup_epochs * train_data_size) + if learning_scale and scale_factor is not None: + lr_config.learning_rate = self.learning_rate_scale(lr_config.learning_rate, scale_factor) - self.config.lr_schedule.total_steps = total_steps \ - if self.config.lr_schedule.total_steps is None or self.config.lr_schedule.total_steps == -1 \ - else int(self.config.lr_schedule.total_steps) + # Build learning rate scheduler from configuration + lr_scheduler = build_lr(lr_config) - self.config.lr_schedule.learning_rate = self.learning_rate_scale( - self.config.lr_schedule.learning_rate, scale_factor) \ - if learning_scale and scale_factor is not None else self.config.lr_schedule.learning_rate - lr_schedule = build_lr(self.config.lr_schedule) - if lr_schedule and hasattr(lr_schedule, "warmup_lr_init") and warmup_lr_init is None: - logger.info(f"warmup_lr_init is not set. The default value {lr_schedule.warmup_lr_init} will be applied.") - return lr_schedule + if lr_scheduler and hasattr(lr_scheduler, "warmup_lr_init") and warmup_lr_init is None: + logger.info(f"warmup_lr_init is not set. The default value {lr_scheduler.warmup_lr_init} will be applied.") + return lr_scheduler def create_model_wrapper(self, network, optimizer): """Create the model wrapper for training.""" @@ -683,8 +730,9 @@ class BaseTrainer: "calculate_per_token_loss": calculate_per_token_loss, "global_norm_spike_threshold": global_norm_spike_threshold, "print_separate_loss": print_separate_loss, - "use_skip_data_by_global_norm": use_skip_data_by_global_norm - }) + "use_skip_data_by_global_norm": use_skip_data_by_global_norm, + "lr_scheduler": self.lr_scheduler, + "grouped_lr_scheduler": self.grouped_lr_scheduler}) return model_wrapper def create_callbacks(self, default_args: dict = None): @@ -903,7 +951,7 @@ class BaseTrainer: Postprocess the training dataset after construction. Mainly used to adjust dataset size for special dataloaders. 
""" - dataloader_config = config.train_dataset.get('data_loader', dict()) + dataloader_config = config.train_dataset.get('data_loader', {}) dataloader_type = dataloader_config.get('type') # Special handling for BlendedMegatronDatasetDataLoader @@ -926,7 +974,7 @@ class BaseTrainer: # Check dataset sink mode with dataset broadcast optimization level self._check_sink_mode_with_ds_broadcast(config) - dataloader_config = config.train_dataset.get('data_loader', dict()) + dataloader_config = config.train_dataset.get('data_loader', {}) dataloader_type = dataloader_config.get('type') # Case 1: BlendedMegatronDatasetDataLoader @@ -990,7 +1038,7 @@ class BaseTrainer: """ cur_rank = get_rank() - src_strategy_files = sorted([f for f in os.listdir(config.src_strategy_path_or_dir)]) + src_strategy_files = sorted(list(os.listdir(config.src_strategy_path_or_dir))) if len(src_strategy_files) - 1 < cur_rank: raise ValueError(f" rank {cur_rank} src_strategy is not exist") src_strategy_file = os.path.join(config.src_strategy_path_or_dir, src_strategy_files[cur_rank]) @@ -1100,9 +1148,9 @@ class BaseTrainer: logger.info("..............Start resume checkpoint path from strategy..............") resume_ckpt_path = self.resume_ckpt_path_with_strategy(config) if resume_ckpt_path is None: - raise ValueError("Try to resume from checkpoints with strategy in directory '{}' failed, " - "please specify load_checkpoint to specific checkpoint file to resume training." - .format(config.load_checkpoint)) + raise ValueError(f"Try to resume from checkpoints with strategy in directory " + f"'{config.load_checkpoint}' failed, please specify load_checkpoint to " + f"specific checkpoint file to resume training.") config.load_checkpoint = resume_ckpt_path load_resume_context_from_checkpoint(config, dataset) resume_dict = { @@ -1211,11 +1259,9 @@ class BaseTrainer: logger.info(".........Build Optimizer For Train..........") if optimizer is None: if config.load_checkpoint: - optimizer = self.create_optimizer_scheduler_without_param_init(network, - model_params, - layer_scale=config.layer_scale) + optimizer = self.create_optimizer_scheduler_without_param_init(network, model_params) else: - optimizer = self.create_optimizer_scheduler(network, model_params, layer_scale=config.layer_scale) + optimizer = self.create_optimizer_scheduler(network, model_params) # build model wrapper logger.info(".........Build Running Wrapper From Config For Train..........") diff --git a/mindformers/trainer/optimizer_grouped_parameters.py b/mindformers/trainer/optimizer_grouped_parameters.py index b614550d0..f2b8734e9 100644 --- a/mindformers/trainer/optimizer_grouped_parameters.py +++ b/mindformers/trainer/optimizer_grouped_parameters.py @@ -16,95 +16,161 @@ import json from typing import Optional +from collections import defaultdict +from fnmatch import fnmatch from mindspore.nn import Cell from mindspore.nn.learning_rate_schedule import LearningRateSchedule from mindformers.models import PreTrainedModel -from mindformers.core.lr import LearningRateWiseLayer from mindformers.tools.logger import logger -from .utils import check_keywords_in_name +from mindformers.trainer.utils import check_keywords_in_name + +# Global list to store grouped parameter names in optimizer +GROUPED_PARAMS = [] + def filter_current_stage_parameters(model, model_params): - """Get current rank trainable parameters in model while use PMA.""" + """ + Disable gradient updates for parameters that are not included in the + current training stage (used in PMA training). 
+ """ if not model_params: - raise ValueError("The model got empty trainable parameters, " - "please check the get_model_parameters method.") + raise ValueError( + "The model got empty trainable parameters, " + "please check the get_model_parameters method." + ) + + # Iterate over all submodules (cells) and their names for _, cell in model.cells_and_names(): for param in cell.trainable_params(): if param not in model_params: param.requires_grad = False + +def _get_gouped_lr_map(model, grouped_lr_scheduler=None): + """ + Build parameter-to-group and group-to-learning-rate mappings + based on grouped learning rate scheduler configuration. + """ + param_group_map = {} + grouped_lr_map = defaultdict(dict) + if not grouped_lr_scheduler: + return param_group_map, grouped_lr_map + + # Map parameter name patterns to group IDs + group_map = {} + for group_id, group_dict in enumerate(grouped_lr_scheduler): + params = group_dict.get('params', None) + lr_scheduler = group_dict.get('lr_scheduler') + lr_config = group_dict.get('lr_config') + + # Assign each param pattern to a group + for param in params: + group_map[param] = group_id + + # Store LR scheduler instance and its config + grouped_lr_map[group_id]['instance'] = lr_scheduler + grouped_lr_map[group_id]['config'] = lr_config + + # Initialize the global grouped parameter tracker + global GROUPED_PARAMS + GROUPED_PARAMS = [[] for _ in range(len(grouped_lr_scheduler))] + + # Match actual parameter names to group patterns + for param in model.trainable_params(): + for grouped_param_name in list(group_map.keys()): + group_id = group_map.get(grouped_param_name) + # Match exact or wildcard parameter names + if grouped_param_name in param.name or fnmatch(param.name, grouped_param_name): + param_group_map[param.name] = group_id + GROUPED_PARAMS[group_id].append(param.name) + break + return param_group_map, grouped_lr_map + + def get_optimizer_grouped_parameters(model: Optional[PreTrainedModel] = None, weight_decay: float = 0.0, dynamic_lr_schedule: Optional[LearningRateSchedule] = None, + optimizer_type: str = 'AdamW', + model_params: set = None, + grouped_lr_schedule: dict = None, layer_scale: bool = False, - layer_decay: float = 1.0, - optimizer_type="AdamW", - model_params=None): - """Get grouped parameters of the network for training.""" + layer_decay: float = 1.0,): + """ + Build optimizer parameter groups with appropriate weight decay, + learning rate scheduling, and optional parameter grouping. 
+ """ if not isinstance(model, (Cell, PreTrainedModel)): raise TypeError(f"model type should be PreTrainedModel, but get {type(model)}") + if layer_scale: + raise ValueError("layer_scale is not supported currently.") - skip_params = {} - skip_keywords = {} + no_wd_params = {} + no_wd_keywords = {} if hasattr(model, 'no_weight_decay'): - skip_params = model.no_weight_decay() - logger.info('No weight decay: %s', skip_params) + no_wd_params = model.no_weight_decay() + logger.info(f'Get no weight decay params: {no_wd_params}') if hasattr(model, 'no_weight_decay_keywords'): - skip_keywords = model.no_weight_decay_keywords() - logger.info('No weight decay keywords: %s', skip_keywords) + no_wd_keywords = model.no_weight_decay_keywords() + logger.info(f'Get no weight decay keywords: {no_wd_keywords}') - decay_parameters_names = [] + # set default values if not provided + if not weight_decay: + weight_decay = 0.0 + if not layer_decay: + layer_decay = 1.0 + # PMA optimizer requires filtering of stage-specific parameters if optimizer_type in ("PmaAdamW", "FusedPmaAdamW"): filter_current_stage_parameters(model, model_params) - for param in model.trainable_params(): - if skip_params or skip_keywords: - if param.name in skip_params: - continue - if check_keywords_in_name(param.name, skip_keywords): - continue - elif len(param.shape) == 1 or param.name.endswith(".bias"): - continue - decay_parameters_names.append(param.name) - - if dynamic_lr_schedule is not None: - if layer_scale: - logger.warning("if use dynamic_lr_schedule and layer_scale, they will be reset and invalid.") - layer_scale = False - dynamic_lr_schedule = None - else: - logger.warning("dynamic_lr_schedule will be reset and invalid when layer_scale is False.") - dynamic_lr_schedule = None + # Build mapping from params to LR groups + param_group_map, lr_scheduler_map = _get_gouped_lr_map(model, grouped_lr_schedule) + parameter_group_names = {} # For logging + parameter_group_vars = {} # Actual optimizer groups - parameter_group_names = {} - parameter_group_vars = {} - scale = 1. + # Iterate over trainable parameters and assign them to groups for param in model.trainable_params(): - if param.name in decay_parameters_names: - group_name = 'decay' - weight_decay_ = weight_decay + param_name = param.name + + no_wd = ( + len(param.shape) == 1 + or param_name.endswith(".bias") + or param_name in no_wd_params + or check_keywords_in_name(param_name, no_wd_keywords) + ) + if no_wd: + wd_mul = 0.0 + group_name = 'no_weight_decay' else: - group_name = 'no_decay' - weight_decay_ = 0. 
+ wd_mul = 1.0 + group_name = 'weight_decay' + group_id = None + if param_name in param_group_map: + group_id = param_group_map.get(param_name) + group_name = f"group_{group_id}_{group_name}" + + # Initialize group if not exists if group_name not in parameter_group_names: parameter_group_names[group_name] = { - "weight_decay": weight_decay_, + "weight_decay": wd_mul * weight_decay, "params": [], } parameter_group_vars[group_name] = { - "weight_decay": weight_decay_, + "weight_decay": wd_mul * weight_decay, "params": [], } - if isinstance(dynamic_lr_schedule, LearningRateSchedule): - if layer_scale: - parameter_group_vars[group_name]["lr"] = LearningRateWiseLayer(dynamic_lr_schedule, scale) - else: - parameter_group_vars[group_name]["lr"] = dynamic_lr_schedule + # Attach LR scheduler if group-specific + if group_id is not None: + cur_lr = lr_scheduler_map[group_id].get('instance') + cur_lr_config = lr_scheduler_map[group_id].get('config') + parameter_group_vars[group_name]["lr"] = cur_lr + parameter_group_names[group_name]["lr_config"] = cur_lr_config + + # Append parameter to its group parameter_group_vars[group_name]["params"].append(param) parameter_group_names[group_name]["params"].append(param.name) diff --git a/mindformers/wrapper/wrapper.py b/mindformers/wrapper/wrapper.py index 855713885..6f009d021 100644 --- a/mindformers/wrapper/wrapper.py +++ b/mindformers/wrapper/wrapper.py @@ -18,13 +18,6 @@ import os import shutil from copy import deepcopy -from mindformers.core.clip_grad import ClipGradNorm -from mindformers.core.context.build_context import is_legacy_model -from mindformers.tools.register import MindFormerRegister, MindFormerModuleType -from mindformers.tools.utils import get_real_rank -from mindformers.utils.parameter_register import parameter_register -from mindformers.version_control import get_identity - from mindspore._checkparam import args_type_check from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.parallel._utils import _get_enable_parallel_optimizer @@ -40,6 +33,13 @@ from mindspore.ops import operations as P from mindspore.ops.auto_generate import SumExt from mindspore.ops._grad_experimental.grad_comm_ops import get_squared_device_local_norm_param +from mindformers.core.clip_grad import ClipGradNorm +from mindformers.core.context.build_context import is_legacy_model +from mindformers.tools.register import MindFormerRegister, MindFormerModuleType +from mindformers.tools.utils import get_real_rank +from mindformers.utils.parameter_register import parameter_register +from mindformers.version_control import get_identity + __all__ = ['MFTrainOneStepCell', 'MFPipelineWithLossScaleCell', 'PipelineCellWithTwoOutput', 'GradAccumulationCellWithTwoOutput'] @@ -114,6 +114,21 @@ def _reset_accu_gbs_fi(network): raise NotImplementedError(f"network: {network} does not Implemented function `reset_accu_gbs_fi`") +def _check_network_with_micro_size(cls_name, network, micro_size): + """ + Check the validity of network and micro_size parameters for cls_name + """ + if not isinstance(network, nn.Cell): + raise TypeError(f"For `{cls_name}`, the argument 'network' must be cell type, " + f"but got the type : {type(network)}.") + if not isinstance(micro_size, int): + raise TypeError(f"For `{cls_name}`, the argument 'micro_size' must be integer, " + f"but got the type : {type(micro_size)}.") + if micro_size <= 0: + raise ValueError(f"For `{cls_name}`, the argument 'micro_size' must be greater than 0, " + f"but got {micro_size}.") + + 
@_grad_scale.register("Tensor", "Tensor") def tensor_grad_scale(scale, grad): return F.cast(grad, mstype.float32) * reciprocal(scale) @@ -153,7 +168,7 @@ def _get_size(grad): class LocalNorm(nn.Cell): def __init__(self): - super(LocalNorm, self).__init__() + super().__init__() self.hyper_map = C.HyperMap() def construct(self, grads): @@ -249,7 +264,7 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell): **kwargs): if isinstance(scale_sense, (int, float)): scale_sense = Tensor(scale_sense) - super(MFTrainOneStepCell, self).__init__(network, optimizer, scale_sense) + super().__init__(network, optimizer, scale_sense) self.use_clip_grad = use_clip_grad self.clip_grad_norm = ClipGradNorm(max_norm=max_grad_norm) self.parallel_config = kwargs.pop("parallel_config", None) @@ -294,6 +309,20 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell): transformer_config.num_layers ) + # Get grouped LR schedulers from base_trainer + grouped_lr_scheduler = kwargs.get('grouped_lr_scheduler') + self.use_grouped_lr = grouped_lr_scheduler is not None + if self.use_grouped_lr: + # Register default_lr + self.default_lr = kwargs.get('lr_scheduler') + self.default_lr_parameter = parameter_register.register( + "current_default_lr", Tensor(0., mstype.float32)) + # Register grouped_lr + self.grouped_lr = nn.CellList( + [deepcopy(current_lr.get('lr_scheduler')) for current_lr in grouped_lr_scheduler]) + self.grouped_lr_parameter = parameter_register.register( + "current_grouped_lr", Tensor([0.] * len(self.grouped_lr), mstype.float32)) + def construct(self, *inputs): """forward and backward.""" scaling_sens = self.scale_sense @@ -326,7 +355,13 @@ class MFTrainOneStepCell(nn.TrainOneStepWithLossScaleCell): learning_rate = self.learning_rate if self.optimizer.dynamic_lr: - if self.optimizer.is_group_lr: + if self.use_grouped_lr: + # Get default_lr, grouped_lr and update parameters registered + learning_rate = self.default_lr(self.optimizer.global_step).reshape(()) + self.default_lr_parameter = learning_rate + for group_id, current_lr in enumerate(self.grouped_lr): + self.grouped_lr_parameter[group_id] = current_lr(self.optimizer.global_step).reshape(()) + elif self.optimizer.is_group_lr: learning_rate = self.learning_rate[-1](self.optimizer.global_step).reshape(()) else: learning_rate = self.learning_rate(self.optimizer.global_step).reshape(()) @@ -414,13 +449,15 @@ class DataOrderWrapperCell(nn.Cell): """For passing parameters in lexicographical order.""" def __init__(self, construct_args_key, network): - super(DataOrderWrapperCell, self).__init__(auto_prefix=False) + super().__init__(auto_prefix=False) self.construct_args_key = construct_args_key self.network = network def construct(self, *inputs): """The construct processes of inputs in lexicographical order.""" - key_inputs = {key: val for key, val in zip(self.construct_args_key, inputs)} + key_inputs = {} + for key, val in zip(self.construct_args_key, inputs): + key_inputs[key] = val return self.network(**key_inputs) @@ -472,15 +509,7 @@ class GradAccumulationCellWithTwoOutput(nn.Cell): self.micro_inputs = nn.CellList() self.micro_size = micro_size self.add_list = [] - if not isinstance(network, nn.Cell): - raise TypeError("For 'GradAccumulationCellWithTwoOutput', the argument 'network' must cell type, " - "but got the type : {}.".format(type(network))) - if not isinstance(micro_size, int): - raise TypeError("For 'GradAccumulationCellWithTwoOutput', the argument 'micro_size' must be integer, " - "but got the type : {}.".format(type(micro_size))) - if 
micro_size <= 0: - raise ValueError("For 'GradAccumulationCellWithTwoOutput', the argument 'micro_size' must be large than 0, " - "but got {}.".format(micro_size)) + _check_network_with_micro_size('GradAccumulationCellWithTwoOutput', network, micro_size) for i in range(micro_size): micro_input = _MicroBatch(micro_size) micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size) @@ -532,15 +561,7 @@ class GradAccumulationCellWithMultiOutputs(nn.Cell): self.micro_inputs = nn.CellList() self.micro_size = micro_size self.add_list = [] - if not isinstance(network, nn.Cell): - raise TypeError("For 'GradAccumulationCellWithTwoOutput', the argument 'network' must cell type, " - "but got the type : {}.".format(type(network))) - if not isinstance(micro_size, int): - raise TypeError("For 'GradAccumulationCellWithTwoOutput', the argument 'micro_size' must be integer, " - "but got the type : {}.".format(type(micro_size))) - if micro_size <= 0: - raise ValueError("For 'GradAccumulationCellWithTwoOutput', the argument 'micro_size' must be large than 0, " - "but got {}.".format(micro_size)) + _check_network_with_micro_size('GradAccumulationCellWithMultiOutputs', network, micro_size) for i in range(micro_size): micro_input = _MicroBatch(micro_size) micro_input.strided_slice.add_prim_attr("grad_accu_num", micro_size) @@ -658,15 +679,7 @@ class PipelineCellWithTwoOutput(nn.Cell): self.micro_inputs = nn.CellList() self.micro_size = micro_size self.add_list = [] - if not isinstance(network, nn.Cell): - raise TypeError("For 'PipelineCellWithTwoOutput', the argument 'network' must cell type, " - "but got the type : {}.".format(type(network))) - if not isinstance(micro_size, int): - raise TypeError("For 'PipelineCellWithTwoOutput', the argument 'micro_size' must be integer, " - "but got the type : {}.".format(type(micro_size))) - if micro_size <= 0: - raise ValueError("For 'PipelineCellWithTwoOutput', the argument 'micro_size' must be large than 0, " - "but got {}.".format(micro_size)) + _check_network_with_micro_size('PipelineCellWithTwoOutput', network, micro_size) for i in range(micro_size): micro_input = _MicroBatch(micro_size) self.micro_inputs.append(micro_input) @@ -724,15 +737,7 @@ class PipelineCellWithMultiOutputs(nn.Cell): self.micro_inputs = nn.CellList() self.micro_size = micro_size self.add_list = [] - if not isinstance(network, nn.Cell): - raise TypeError("For 'PipelineCellWithTwoOutput', the argument 'network' must cell type, " - "but got the type : {}.".format(type(network))) - if not isinstance(micro_size, int): - raise TypeError("For 'PipelineCellWithTwoOutput', the argument 'micro_size' must be integer, " - "but got the type : {}.".format(type(micro_size))) - if micro_size <= 0: - raise ValueError("For 'PipelineCellWithTwoOutput', the argument 'micro_size' must be large than 0, " - "but got {}.".format(micro_size)) + _check_network_with_micro_size('PipelineCellWithMultiOutputs', network, micro_size) for i in range(micro_size): micro_input = _MicroBatch(micro_size) self.micro_inputs.append(micro_input) @@ -870,7 +875,7 @@ class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): use_skip_data_by_global_norm=False, print_separate_loss=False, **kwargs): if isinstance(scale_sense, (int, float)): scale_sense = Tensor(scale_sense) - super(MFPipelineWithLossScaleCell, self).__init__(network, optimizer, scale_sense) + super().__init__(network, optimizer, scale_sense) self.network = network self.network.add_flags(defer_inline=True) self.weights = optimizer.parameters @@ -897,13 
+902,12 @@ class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32), name="scale_sense") elif isinstance(scale_sense, Tensor): - if scale_sense.shape == (1,) or scale_sense.shape == (): + if scale_sense.shape in ((1, ), ()): self.scale_sense = Parameter(scale_sense, name='scale_sense') else: - raise ValueError("The shape of 'scale_sense' must be (1,) or (), but got {}" - .format(scale_sense.shape)) + raise ValueError(f"The shape of 'scale_sense' must be (1,) or (), but got {scale_sense.shape}") else: - raise TypeError("The 'scale_sense' must be Cell or Tensor, but got {}".format(type(scale_sense))) + raise TypeError(f"The 'scale_sense' must be Cell or Tensor, but got {type(scale_sense)}") self.opt_shard = _get_enable_parallel_optimizer() self.use_clip_grad = use_clip_grad self.clip_grad_norm = ClipGradNorm(max_norm=max_grad_norm) @@ -953,6 +957,20 @@ class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): transformer_config.num_layers ) + # Get grouped LR schedulers from base_trainer + grouped_lr_scheduler = kwargs.get('grouped_lr_scheduler') + self.use_grouped_lr = grouped_lr_scheduler is not None + if self.use_grouped_lr: + # Register default_lr + self.default_lr = kwargs.get('lr_scheduler') + self.default_lr_parameter = parameter_register.register( + "current_default_lr", Tensor(0., mstype.float32)) + # Register grouped_lr + self.grouped_lr = nn.CellList( + [deepcopy(current_lr.get('lr_scheduler')) for current_lr in grouped_lr_scheduler]) + self.grouped_lr_parameter = parameter_register.register( + "current_grouped_lr", Tensor([0.] * len(self.grouped_lr), mstype.float32)) + @C.add_flags(has_effect=True) def construct(self, *inputs): """The construct processes of pipeline wrapper cell.""" @@ -986,7 +1004,13 @@ class MFPipelineWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): learning_rate = self.learning_rate if self.optimizer.dynamic_lr: - if self.optimizer.is_group_lr: + if self.use_grouped_lr: + # Get default_lr, grouped_lr and update parameters registered + learning_rate = self.default_lr(self.optimizer.global_step).reshape(()) + self.default_lr_parameter = learning_rate + for group_id, current_lr in enumerate(self.grouped_lr): + self.grouped_lr_parameter[group_id] = current_lr(self.optimizer.global_step).reshape(()) + elif self.optimizer.is_group_lr: learning_rate = self.learning_rate[-1](self.optimizer.global_step).reshape(()) else: learning_rate = self.learning_rate(self.optimizer.global_step).reshape(()) diff --git a/tests/st/test_ut/test_optimizer_grouped_parameters.py b/tests/st/test_ut/test_optimizer_grouped_parameters.py new file mode 100644 index 000000000..e0d1e9279 --- /dev/null +++ b/tests/st/test_ut/test_optimizer_grouped_parameters.py @@ -0,0 +1,149 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test get_optimizer_grouped_parameters api.""" + +import pytest +import numpy as np + +import mindspore as ms +from mindspore import Tensor, Parameter, nn + +from mindformers.trainer.optimizer_grouped_parameters import get_optimizer_grouped_parameters +from mindformers.core.lr.lr_schedule import LinearWithWarmUpLR +from mindformers.tools.register.config import MindFormerConfig + + +class Bias(nn.Cell): + """ A simple bias module for test. """ + + def __init__(self): + super().__init__(auto_prefix=True) + self.bias = Parameter(Tensor([0.1], ms.int32), name="bias", requires_grad=True) + + def construct(self, x): + return x + self.bias + + +class Net(nn.Cell): + """ A simple net for test. """ + + def __init__(self): + super().__init__(auto_prefix=True) + self.weight = Parameter(Tensor(np.random.rand(128, 512), ms.float32), name="weight", requires_grad=True) + self.value = Parameter(Tensor([2], ms.int32), name="value", requires_grad=True) + self.model = Bias() + + def construct(self, x): + x = x * self.weight * self.value + output = self.model(x) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_get_grouped_params(): + """ + Feature: get_optimizer_grouped_parameters api + Description: Test get_optimizer_grouped_parameters function + Expectation: No exception. + """ + + model = Net() + weight_decay = 0.01 + dynamic_lr_schedule = LinearWithWarmUpLR( + learning_rate=0.001, + total_steps=100, + warmup_steps=0, + warmup_lr_init=0.0, + warmup_ratio=None + ) + + grouped_params = get_optimizer_grouped_parameters( + model=model, + weight_decay=weight_decay, + dynamic_lr_schedule=dynamic_lr_schedule, + layer_scale=False, + layer_decay=1.0, + # use for ("PmaAdamW", "FusedPmaAdamW") + optimizer_type='AdamW', + model_params=None + ) + + target_dict = [ + {'weight_decay': 0.01, 'params': [model.weight]}, + {'weight_decay': 0.0, 'params': [model.value, model.model.bias]}, + ] + assert grouped_params == target_dict, f"Get params {grouped_params}, but should be {target_dict}." + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_get_grouped_params_with_grouped_lr(): + """ + Feature: get_optimizer_grouped_parameters api + Description: Test get_optimizer_grouped_parameters function with grouped lr scheduler + Expectation: No exception. 
+ """ + + model = Net() + weight_decay = 0.01 + dynamic_lr_schedule = LinearWithWarmUpLR( + learning_rate=0.001, + total_steps=100, + warmup_steps=0, + warmup_lr_init=0.0, + warmup_ratio=None + ) + + lr_config = MindFormerConfig(**{ + 'type':'LinearWithWarmUpLR', + 'params': ['value*'], + 'learning_rate': 1.e-6, + 'warmup_steps': 0, + 'total_steps': -1 + }) + lr_scheduler = LinearWithWarmUpLR( + learning_rate=1.e-6, + total_steps=100, + warmup_steps=0, + warmup_lr_init=0.0, + warmup_ratio=None + ) + grouped_lr_schedule = [{ + 'params': lr_config.params, + 'lr_scheduler': lr_scheduler, + 'lr_config': lr_config + }] + + grouped_params = get_optimizer_grouped_parameters( + model=model, + weight_decay=weight_decay, + dynamic_lr_schedule=dynamic_lr_schedule, + layer_scale=False, + layer_decay=1.0, + # use for ("PmaAdamW", "FusedPmaAdamW") + optimizer_type='AdamW', + model_params=None, + grouped_lr_schedule=grouped_lr_schedule, + ) + + target_dict = [ + {'weight_decay': 0.01, 'params': [model.weight]}, + {'weight_decay': 0.0, 'params': [model.value], 'lr': lr_scheduler}, + {'weight_decay': 0.0, 'params': [model.model.bias]} + ] + assert grouped_params == target_dict, f"Get params {grouped_params}, but should be {target_dict}." -- Gitee