diff --git a/mindformers/core/callback_pynative/__init__.py b/mindformers/core/callback_pynative/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb7fabbe73a78a8d9098578f960b948f0698f771 --- /dev/null +++ b/mindformers/core/callback_pynative/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Callback module for MindFormers Trainer.""" + +from .callback import TrainerCallback, CallbackHandler +from .loss_callback import LossCallback +from .checkpoint_callback import CheckpointCallback + +__all__ = [ + 'TrainerCallback', + 'CallbackHandler', + 'LossCallback', + 'CheckpointCallback' +] diff --git a/mindformers/core/callback_pynative/callback.py b/mindformers/core/callback_pynative/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..88aa6502a496ef378cf3fdbd3d6c182eed0d4948 --- /dev/null +++ b/mindformers/core/callback_pynative/callback.py @@ -0,0 +1,353 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Callback base classes and handler for Trainer.""" +import abc +from typing import List, Optional + +from mindformers.tools.logger import logger + + +class TrainerCallback(metaclass=abc.ABCMeta): + """ + Base class for callbacks that can be registered with the Trainer. + + A callback can execute custom code at various points during training, + including at the beginning/end of training, epoch, and step. + + All callback methods receive the following parameters: + args: Training arguments + state: Current training state + **kwargs: Additional keyword arguments including model, optimizer, etc. + """ + + def on_begin(self, args, state, **kwargs): + """ + Event called at the beginning of a task. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_end(self, args, state, **kwargs): + """ + Event called at the end of a task. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_train_begin(self, args, state, **kwargs): + """ + Event called at the beginning of training. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_train_end(self, args, state, **kwargs): + """ + Event called at the end of training. 
+ + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_epoch_begin(self, args, state, **kwargs): + """ + Event called at the beginning of an epoch. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_epoch_end(self, args, state, **kwargs): + """ + Event called at the end of an epoch. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_step_begin(self, args, state, **kwargs): + """ + Event called at the beginning of a training step. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + def on_step_end(self, args, state, **kwargs): + """ + Event called at the end of a training step. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + + +class CallbackHandler: + """ + Internal class that manages and calls all registered callbacks. + + This class is responsible for: + - Managing the list of callbacks + - Adding/removing callbacks + - Calling all callbacks at appropriate events + + Args: + callbacks (List[TrainerCallback], optional): + List of callbacks to register initially + model: The model being trained + train_dataset: Training dataset + eval_dataset: Evaluation dataset + optimizer: Optimizer instance + lr_scheduler: Learning rate scheduler instance + """ + + def __init__( + self, + callbacks: Optional[List[TrainerCallback]] = None, + model=None, + train_dataset=None, + eval_dataset=None, + optimizer=None, + lr_scheduler=None + ): + """ + Initialize the callback handler. 
+ + Args: + callbacks: List of TrainerCallback instances + model: Model instance + train_dataset: Training dataset instance + eval_dataset: Evaluation dataset instance + optimizer: Optimizer instance + lr_scheduler: Learning rate scheduler instance + """ + self.callbacks = [] + if callbacks is not None: + for cb in callbacks: + self.add_callback(cb) + + self.model = model + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + def add_callback(self, callback: TrainerCallback): + """ + Add a callback to the handler. + + Args: + callback: TrainerCallback instance or class to add + """ + # If callback is a class, instantiate it + cb = callback() if isinstance(callback, type) else callback + cb_class = callback if isinstance(callback, type) else callback.__class__ + + # Check if this type of callback already exists + existing_callbacks = [c.__class__ for c in self.callbacks] + if cb_class in existing_callbacks: + logger.warning( + f"You are adding a {cb_class.__name__} to the callbacks of this Trainer, " + f"but there is already one. The current list of callbacks is: " + f"{[c.__class__.__name__ for c in self.callbacks]}" + ) + + self.callbacks.append(cb) + + def pop_callback(self, callback): + """ + Remove and return a callback from the handler. + + Args: + callback: TrainerCallback instance or class to remove + + Returns: + The removed callback instance, or None if not found + """ + if isinstance(callback, type): + # callback is a class, find instance of that class + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return cb + else: + # callback is an instance + for cb in self.callbacks: + if cb == callback: + self.callbacks.remove(cb) + return cb + return None + + def remove_callback(self, callback): + """ + Remove a callback from the handler without returning it. 
+ + Args: + callback: TrainerCallback instance or class to remove + """ + if isinstance(callback, type): + # callback is a class, remove all instances of that class + for cb in self.callbacks[:]: # Copy list to avoid modification during iteration + if isinstance(cb, callback): + self.callbacks.remove(cb) + return + else: + # callback is an instance + if callback in self.callbacks: + self.callbacks.remove(callback) + + def on_begin(self, args, state, **kwargs): + """ + Call on_begin for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_begin", args, state, **kwargs) + + def on_end(self, args, state, **kwargs): + """ + Call on_end for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_end", args, state, **kwargs) + + def on_train_begin(self, args, state, **kwargs): + """ + Call on_train_begin for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_train_begin", args, state, **kwargs) + + def on_train_end(self, args, state, **kwargs): + """ + Call on_train_end for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_train_end", args, state, **kwargs) + + def on_epoch_begin(self, args, state, **kwargs): + """ + Call on_epoch_begin for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_epoch_begin", args, state, **kwargs) + + def on_epoch_end(self, args, state, **kwargs): + """ + Call on_epoch_end for all registered callbacks. 
+ + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_epoch_end", args, state, **kwargs) + + def on_step_begin(self, args, state, **kwargs): + """ + Call on_step_begin for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_step_begin", args, state, **kwargs) + + def on_step_end(self, args, state, **kwargs): + """ + Call on_step_end for all registered callbacks. + + Args: + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + """ + return self.call_event("on_step_end", args, state, **kwargs) + + def call_event(self, event: str, args, state, **kwargs): + """ + Call a specific event on all registered callbacks. + + Args: + event: Name of the event method to call + args: Training arguments + state: Current trainer state + **kwargs: Additional keyword arguments + + Returns: + Result from the last callback (if any) + """ + result = None + for callback in self.callbacks: + result = getattr(callback, event)( + args, + state, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + **kwargs, + ) + return result + + @property + def callback_list(self) -> str: + """ + Get a string representation of all registered callbacks. 
+ + Returns: + String listing all callback class names + """ + return "\n".join([cb.__class__.__name__ for cb in self.callbacks]) diff --git a/mindformers/core/callback_pynative/checkpoint_callback.py b/mindformers/core/callback_pynative/checkpoint_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..17057d0935c1aca1f8136e0a27d9eedc5e8afa6c --- /dev/null +++ b/mindformers/core/callback_pynative/checkpoint_callback.py @@ -0,0 +1,210 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Checkpoint callback for saving model checkpoints during training.""" +import os +from mindformers.core.callback_pynative.callback import TrainerCallback +from mindformers.tools.logger import logger +from mindformers.checkpoint import save_checkpoint +from mindformers.checkpoint.checkpoint import CommonInfo, AsyncSaveManager + + +class CheckpointCallback(TrainerCallback): + """ + Callback for saving model checkpoints during training. + + This callback saves model checkpoints at specified intervals and at the end of training. + It can save both the model parameters and optimizer state. + + Args: + save_dir (str): Directory where checkpoints will be saved + save_interval (int): Number of steps between checkpoint saves. Default: 1000 + save_optimizer (bool): Whether to save optimizer state. 
Default: True + keep_checkpoint_max (int): Maximum number of checkpoints to keep. Default: 5 + save_on_train_end (bool): Whether to save checkpoint at the end of training. Default: True + user_prefix (str): Prefix for checkpoint file names. Default: "checkpoint" + async_save (bool): Enable async save. Default: False + remove_redundancy (bool): Whether to remove redundancy when saving. Default: False + """ + + def __init__( + self, + save_dir: str, + save_interval: int = 1000, + save_optimizer: bool = True, + keep_checkpoint_max: int = 5, + save_on_train_end: bool = True, + user_prefix: str = "checkpoint", + async_save: bool = False, + remove_redundancy: bool = False + ): + """ + Initialize the CheckpointCallback. + + Args: + save_dir: Directory path for saving checkpoints + save_interval: Steps between checkpoint saves + save_optimizer: Whether to save optimizer state + keep_checkpoint_max: Maximum number of checkpoints to keep + save_on_train_end: Whether to save at training end + user_prefix: Prefix for checkpoint file names + async_save: Enable async save + remove_redundancy: Whether to remove redundancy when saving + """ + super().__init__() + self.save_dir = save_dir + self.save_interval = save_interval + self.save_optimizer = save_optimizer + self.keep_checkpoint_max = keep_checkpoint_max + self.save_on_train_end = save_on_train_end + self.user_prefix = user_prefix + self.async_save = async_save + self.remove_redundancy = remove_redundancy + + # Create async save manager if needed + self.async_save_manager = None + if self.async_save: + self.async_save_manager = AsyncSaveManager(async_save=True) + logger.info("AsyncSaveManager created") + + # pylint: disable=unused-argument + def on_train_begin(self, args, state, **kwargs): + """ + Called at the beginning of training. + + Creates the save directory if it doesn't exist. 
+ + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments + """ + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir, exist_ok=True) + logger.info(f"Created checkpoint directory: {self.save_dir}") + + def on_step_end(self, args, state, **kwargs): + """ + Called at the end of each training step. + + Saves checkpoint if the current step matches the save interval. + + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments including: + - model: The model to save + - optimizer: The optimizer to save (if save_optimizer=True) + """ + # Check if we should save at this step + if state.global_step % self.save_interval != 0: + return + + self._save_checkpoint(args, state, **kwargs) + + def on_train_end(self, args, state, **kwargs): + """ + Called at the end of training. + + Saves a final checkpoint if save_on_train_end is True. + + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments + """ + if self.save_on_train_end: + self._save_checkpoint(args, state, is_final=True, **kwargs) + logger.info("Training completed. Final checkpoint saved.") + + # pylint: disable=unused-argument + def _save_checkpoint(self, args, state, is_final=False, **kwargs): + """ + Save a checkpoint using mindformers.checkpoint.save_checkpoint. 
+ + Args: + args: Training arguments + state: Trainer state + is_final: Whether this is the final checkpoint + **kwargs: Additional keyword arguments including model and optimizer + """ + model = kwargs.get("model", None) + optimizer = kwargs.get("optimizer", None) + + if model is None: + logger.warning("No model provided to CheckpointCallback, skipping save.") + return + + # Create CommonInfo from TrainerState (always required) + common_info = self._create_common_info(state) + + try: + # Prepare async save manager if needed + if self.async_save_manager is not None: + self.async_save_manager.prepare_before_save() + + # Call mindformers save_checkpoint with full parameters + save_checkpoint( + iteration=state.global_step, + network=model, + optimizer=optimizer if self.save_optimizer else None, + async_save_manager=self.async_save_manager, + common_info=common_info, + keep_max_num=self.keep_checkpoint_max, + user_prefix=self.user_prefix, + save_checkpoint_path=self.save_dir, + remove_redundancy=self.remove_redundancy + ) + + logger.info( + f"Checkpoint saved at step {state.global_step} to {self.save_dir} " + f"(async={self.async_save}, remove_redundancy={self.remove_redundancy})" + ) + + except Exception as e: + logger.error(f"Error saving checkpoint: {e}") + + def _create_common_info(self, state) -> CommonInfo: + """ + Create CommonInfo from TrainerState. 
+ + Args: + state: Trainer state containing training information + + Returns: + CommonInfo instance + """ + common_info = CommonInfo() + + # Extract information from state + if hasattr(state, 'epoch'): + common_info.epoch_num = int(state.epoch) + + if hasattr(state, 'global_step'): + common_info.global_step = state.global_step + + if hasattr(state, 'epoch_step') and state.epoch_step > 0: + # Calculate step_num within current epoch + common_info.step_num = state.global_step % state.epoch_step + + # Try to get batch size if available + if hasattr(state, 'global_batch_size'): + common_info.global_batch_size = state.global_batch_size + + logger.debug( + f"Created CommonInfo: epoch={common_info.epoch_num}, " + f"step={common_info.step_num}, global_step={common_info.global_step}" + ) + + return common_info diff --git a/mindformers/core/callback_pynative/loss_callback.py b/mindformers/core/callback_pynative/loss_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..1f3bcc509bd6e0b4e97eabb3a476871c5d220a99 --- /dev/null +++ b/mindformers/core/callback_pynative/loss_callback.py @@ -0,0 +1,184 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Loss callback for logging training loss.""" +import time +from typing import Dict, Any +from mindformers.core.callback_pynative.callback import TrainerCallback +from mindformers.tools.logger import logger + + +class LossCallback(TrainerCallback): + """ + Callback for logging loss information during training. + + This callback logs the training loss at the end of each training step. + + Compared to MFLossMonitor: + - MFLossMonitor supports: pipeline parallel loss fixing, MoE/MTP separate loss, + model FLOPs calculation, throughput computation, time remaining estimation, + overflow/scaling_sens monitoring, global_norm logging, TensorBoard integration + - LossCallback: Simplified version for basic loss logging only + + Args: + log_interval (int): Number of steps between loss logging. Default: 1 + + TODOs: + - TODO: Support overflow and scaling_sens monitoring (similar to MFLossMonitor) + - TODO: Support global_norm logging + - TODO: Support MoE/MTP separate loss printing + - TODO: Support throughput calculation + - TODO: Support model FLOPs calculation + - TODO: Support TensorBoard integration + - TODO: Support pipeline parallel loss fixing + """ + + def __init__(self, log_interval: int = 1): + """ + Initialize the LossCallback. + + Args: + log_interval: How often to log loss (in steps) + """ + super().__init__() + self.log_interval = log_interval + self.step_time = time.time() + self.epoch_time = time.time() + + # pylint: disable=unused-argument + def on_train_begin(self, args, state, **kwargs): + """ + Called at the beginning of training. + + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments + """ + self.step_time = time.time() + self.epoch_time = time.time() + + # pylint: disable=unused-argument + def on_epoch_begin(self, args, state, **kwargs): + """ + Called at the beginning of each epoch. 
+ + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments + """ + self.epoch_time = time.time() + + # pylint: disable=unused-argument + def on_step_begin(self, args, state, **kwargs): + """ + Called at the beginning of each training step. + + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments + """ + self.step_time = time.time() + + # pylint: disable=unused-argument + def on_step_end(self, args, state, **kwargs): + """ + Called at the end of each training step. + + Logs the loss value and optionally computes statistics. + + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments including: + - loss: The current step loss value + """ + loss = kwargs.get("loss", None) + + if loss is None: + return + + # Convert loss to float if it's a tensor + if hasattr(loss, "asnumpy"): + loss_value = float(loss.asnumpy()) + elif hasattr(loss, "item"): + loss_value = loss.item() + else: + loss_value = float(loss) + + # Log loss at specified intervals + if state.global_step % self.log_interval == 0: + cur_time = time.time() + step_time_cost = (cur_time - self.step_time) * 1000 # Convert to milliseconds + + # Prepare log information + log_info = { + "loss": loss_value, + "cur_step": state.global_step, + "max_steps": state.max_steps, + "step_time": step_time_cost, + } + + # Implement later: Extract learning rate from optimizer or lr_scheduler + # Similar to MFLossMonitor's approach: + # - Get from kwargs['lr_scheduler'] if available + # - Or get from args.optimizer.global_step + learning_rate_schedule + + # Print log information + self._print_log(log_info) + + # pylint: disable=unused-argument + def on_epoch_end(self, args, state, **kwargs): + """ + Called at the end of each epoch. 
+ + Args: + args: Training arguments + state: Trainer state + **kwargs: Additional keyword arguments + """ + epoch_time = time.time() - self.epoch_time + logger.info(f"Epoch {state.epoch} finished. Time: {epoch_time:.2f}s") + + def _print_log(self, log_info: Dict[str, Any]): + """ + Print formatted log information in MFLossMonitor style. + + Format similar to MFLossMonitor: + "step:[cur_step/max_steps], loss: X.XXXXXX, per_step_time: Xms, lr: X.XXXe-XX" + + Args: + log_info: Dictionary containing log information + state: Trainer state for accessing additional info + """ + cur_step = log_info.get('cur_step', 0) + max_steps = log_info.get('max_steps', 0) + loss = log_info.get('loss', 0) + per_step_time = int(log_info.get('step_time', 0)) + + # Build log string in MFLossMonitor format + log_parts = [f"step:[{cur_step:5d}/{max_steps:5d}]"] + log_parts.append(f"loss: {loss:.6f}") + log_parts.append(f"per_step_time: {per_step_time}ms") + + if "learning_rate" in log_info: + lr = log_info["learning_rate"] + if isinstance(lr, (list, tuple)): + lr = lr[0] + log_parts.append(f"lr: {lr:.6e}") + + # Format: "{ step:[X/Y], loss: X.XXXXXX, per_step_time: Xms, lr: X.XXXe-XX }" + logger.info("{ " + ", ".join(log_parts) + " }") diff --git a/mindformers/trainer_pynative/__init__.py b/mindformers/trainer_pynative/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abbc9044cceb3d8a1961599e052c6805267ce164 --- /dev/null +++ b/mindformers/trainer_pynative/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Trainer module for MindFormers.""" + +from .trainer import Trainer, TrainMode +from .train_state import TrainerState + +__all__ = ['Trainer', 'TrainMode', 'TrainerState'] diff --git a/mindformers/trainer_pynative/train_state.py b/mindformers/trainer_pynative/train_state.py new file mode 100644 index 0000000000000000000000000000000000000000..e48813c937f0641a2a41ee11f7ee4a5a332d720d --- /dev/null +++ b/mindformers/trainer_pynative/train_state.py @@ -0,0 +1,105 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""TrainerState for tracking training progress.""" +from dataclasses import dataclass +from typing import Optional, Dict, Any + + +@dataclass +class TrainerState: + """ + A class containing the state of the :class:`~Trainer` during training. + + Args: + epoch (float): + Current epoch number. Can be fractional for partial epochs. 
+ global_step (int): + Current global training step. + max_steps (int): + Total number of training steps to perform. + eval_steps (int): + Number of steps between evaluations. + save_steps (int): + Number of steps between checkpoint saves. + epoch_step (int): + Number of steps in one epoch. Used to determine epoch boundaries. + global_batch_size (int): + Global batch size across all devices. + best_metric (float): + Best metric value achieved so far. + best_model_checkpoint (str): + Path to the best model checkpoint. + is_train_begin (bool): + Whether training has begun. + is_train_end (bool): + Whether training has ended. + """ + + epoch: float = 0.0 + global_step: int = 0 + max_steps: int = 0 + eval_steps: int = 0 + save_steps: int = 0 + epoch_step: int = 0 + global_batch_size: int = 0 + best_metric: Optional[float] = None + best_model_checkpoint: Optional[str] = None + is_train_begin: bool = False + is_train_end: bool = False + + def update_epoch(self): + """Update epoch based on current step and epoch_step.""" + if self.epoch_step > 0: + self.epoch = self.global_step / self.epoch_step + + def save_to_dict(self) -> Dict[str, Any]: + """ + Save the state to a dictionary. + + Returns: + Dict[str, Any]: Dictionary containing the state. + """ + return { + "epoch": self.epoch, + "global_step": self.global_step, + "max_steps": self.max_steps, + "eval_steps": self.eval_steps, + "save_steps": self.save_steps, + "epoch_step": self.epoch_step, + "global_batch_size": self.global_batch_size, + "best_metric": self.best_metric, + "best_model_checkpoint": self.best_model_checkpoint, + } + + @classmethod + def load_from_dict(cls, state_dict: Dict[str, Any]) -> "TrainerState": + """ + Load the state from a dictionary. + + Args: + state_dict (Dict[str, Any]): Dictionary containing the state. + + Returns: + TrainerState: The loaded state object. 
+ """ + return cls(**state_dict) + + def __repr__(self): + """Return string representation of the state.""" + return ( + f"TrainerState(epoch={self.epoch}, " + f"global_step={self.global_step}, " + f"max_steps={self.max_steps})" + ) diff --git a/mindformers/trainer_pynative/trainer.py b/mindformers/trainer_pynative/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..2ab115d5e4f8be6b912eb0dfdff66ea19ad7c162 --- /dev/null +++ b/mindformers/trainer_pynative/trainer.py @@ -0,0 +1,679 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Trainer for training models with MindFormers.""" +import os +import enum +from typing import Union, Optional, Callable, List, Dict, Any + +from mindspore.dataset import Dataset + +# Import from mindformers +from mindformers.dataset import build_dataset +from mindformers.models import PreTrainedModel, build_network +from mindformers.checkpoint.checkpoint import load_checkpoint +from mindformers.trainer.training_args import TrainingArguments +from mindformers.tools.logger import logger +from mindformers.tools import MindFormerConfig +from mindformers.core import build_lr, build_optim, build_callback, build_metric +from mindformers.pet import get_pet_model +from mindformers.core.callback_pynative import CallbackHandler +from mindformers.trainer.optimizer_grouped_parameters import get_optimizer_grouped_parameters +from mindformers.trainer_pynative.train_state import TrainerState + + +class TrainMode(enum.Enum): + """Training mode enumeration.""" + FINETUNE = "finetune" + PRETRAIN = "pretrain" + + +class Trainer: + """ + Trainer for training models in MindFormers. + + The Trainer class provides a unified interface for training models with support for: + - Model training and evaluation + - Checkpoint management + - Callback system + - Custom loss functions + - Distributed training + + Args: + model: Model instance or None. If None, will be built from config. 
+ config: Either a path to yaml config file or a TrainingArguments instance + compute_loss_func: Optional custom loss function + train_dataset: Training dataset instance or None + eval_dataset: Evaluation dataset instance or None + processing_class: Optional processor for data preprocessing + optimizer: Optimizer instance or None + lr_scheduler: Learning rate scheduler instance or None + compute_metrics: Optional function to compute evaluation metrics + callbacks: List of callback instances + """ + + def __init__( + self, + model: PreTrainedModel = None, + config: Union[str, TrainingArguments] = None, + compute_loss_func: Optional[Callable] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Dataset] = None, + processing_class: Optional[Any] = None, + optimizer: Optional[Any] = None, + lr_scheduler: Optional[Any] = None, + compute_metrics: Optional[Callable] = None, + callbacks: Optional[List] = None + ): + """ + Initialize the Trainer. + + Args: + model: Model instance + config: YAML configuration file path (str) or TrainingArguments instance + compute_loss_func: Custom loss function + train_dataset: Training dataset + eval_dataset: Evaluation dataset + processing_class: Data processor + optimizer: Optimizer instance + lr_scheduler: Learning rate scheduler + compute_metrics: Metrics computation function + callbacks: List of callbacks + """ + # Initialize config + self.config = self._init_config(config) + + # Verify instance validity when config is yaml file + if isinstance(config, str): + if any([model, train_dataset, eval_dataset, optimizer, lr_scheduler, callbacks]): + logger.warning( + "When config is a yaml file, model/dataset/optimizer/lr_scheduler/callbacks " + "instances should not be provided. They will be built from config." 
+ ) + + # Create model + self.model = self._create_model( + model, + getattr(self.config, 'model', None) + ) + + # Create datasets + self.train_dataset = self._create_dataset( + train_dataset, + getattr(self.config, 'train_dataset', None) + ) + self.eval_dataset = self._create_dataset( + eval_dataset, + getattr(self.config, 'eval_dataset', None) + ) + + # Create optimizer and scheduler + self.optimizer, self.lr_scheduler = self._create_optimizer_and_scheduler( + optimizer, + lr_scheduler, + getattr(self.config, 'optimizer', None), + getattr(self.config, 'lr_schedule', None) + ) + + # Create callback handler + self.callback_handler = self._create_callback_handler( + callbacks, + self.config + ) + + # Create metrics + self.compute_metrics = self._create_metrics( + compute_metrics, + getattr(self.config, 'metric', None) + ) + + # Store other parameters + self.compute_loss_func = compute_loss_func + self.processing_class = processing_class + + # Initialize training state + self.state = None + + def _init_config(self, config: Union[str, TrainingArguments]) -> MindFormerConfig: + """ + Initialize trainer config from yaml file or TrainingArguments instance. + + This method converts config inputs to MindFormerConfig: + - yaml file path (str) -> MindFormerConfig + - TrainingArguments -> MindFormerConfig + + Args: + config: Either a yaml file path (str) or TrainingArguments instance + + Returns: + MindFormerConfig instance + + Raises: + ValueError: If config is None + FileNotFoundError: If yaml file does not exist + """ + if config is None: + raise ValueError("config cannot be None. 
Please provide a yaml file path or TrainingArguments instance.") + + # If config is a string (yaml file path), load it as MindFormerConfig + if isinstance(config, str): + if not os.path.exists(config): + raise FileNotFoundError(f"Config file not found: {config}") + + logger.info(f"Loading config from yaml file: {config}") + return MindFormerConfig(config) + + # If config is TrainingArguments, convert to MindFormerConfig + if isinstance(config, TrainingArguments): + logger.info("Converting TrainingArguments to MindFormerConfig") + # Convert TrainingArguments to dict first + config_dict = {} + for key in dir(config): + if not key.startswith('_') and not callable(getattr(config, key)): + config_dict[key] = getattr(config, key) + return MindFormerConfig(**config_dict) + + # Should not reach here due to type hints + raise TypeError(f"config must be str or TrainingArguments, got {type(config)}") + + def _create_model(self, model, model_config: Optional[Dict]) -> Any: + """ + Create or validate model instance. + + Args: + model: User-provided model instance or None + model_config: Model configuration from yaml + + Returns: + Model instance + """ + # If user provided model instance, use it directly + if model is not None: + logger.info("Using user-provided model instance.") + return model + + # Build model from config + if model_config is None: + raise ValueError("Either model instance or model_config must be provided.") + + logger.info("Building model from config...") + model = build_network(model_config) + + # Apply PET if provided (after base model is built) + pet_config = getattr(self.config, 'pet_config', None) + if pet_config is not None: + logger.info("Applying PET configuration to model...") + model = get_pet_model(model, pet_config) + + return model + + # pylint: disable=unused-argument + def _wrapper_model(self, model, config: Dict) -> Any: + """ + Wrap model for distributed training (HSDP). 
+ + Args: + model: Model to wrap + config: Wrapper configuration (reserved for future use) + + Returns: + Wrapped model + """ + # Reserved interface: currently no wrapper logic required + logger.info("Wrapper is a no-op. Returning model as-is.") + return model + + def _create_dataset( + self, + dataset, + dataset_config: Optional[Dict] + ) -> Optional[Any]: + """ + Create or validate dataset instance. + + Args: + dataset: User-provided dataset instance or None + dataset_config: Dataset configuration from yaml + + Returns: + Dataset instance or None + """ + # If user provided dataset instance, use it directly + if dataset is not None: + logger.info("Using user-provided dataset instance.") + return dataset + + # If no config, return None + if dataset_config is None: + return None + + # Build dataset from config + logger.info("Building dataset from config...") + return build_dataset(dataset_config) + + def _create_optimizer_and_scheduler( + self, + optimizer, + lr_scheduler, + optimizer_config: Optional[Dict], + lr_config: Optional[Dict] + ) -> tuple: + """ + Create optimizer and learning rate scheduler. 
+ + Args: + optimizer: User-provided optimizer instance or None + lr_scheduler: User-provided LR scheduler instance or None + optimizer_config: Optimizer configuration from yaml + lr_config: LR scheduler configuration from yaml + + Returns: + Tuple of (optimizer, lr_scheduler) + """ + # If user provided instances, use them directly + if optimizer is not None and lr_scheduler is not None: + logger.info("Using user-provided optimizer and lr_scheduler instances.") + return optimizer, lr_scheduler + + # Build from config + if optimizer_config is None or lr_config is None: + logger.warning("No optimizer or lr_scheduler config provided.") + return None, None + + logger.info("Building optimizer and lr_scheduler from config...") + + # Build learning rate scheduler first + lr = build_lr(lr_config) + + # Get grouped parameters using official utility + weight_decay = getattr(optimizer_config, 'weight_decay', 0.0) if optimizer_config else 0.0 + grouped_params = get_optimizer_grouped_parameters( + model=self.model, + weight_decay=weight_decay, + dynamic_lr_schedule=lr, + layer_scale=False, + layer_decay=1.0, + optimizer_type=getattr(optimizer_config, 'type', 'AdamW') if optimizer_config else 'AdamW', + model_params=None + ) + + # Build optimizer using default_args to inject params and lr + default_args = { + 'params': grouped_params, + 'learning_rate': lr + } + optimizer = build_optim(optimizer_config, default_args=default_args) + + return optimizer, lr + + def _create_callback_handler( + self, + callbacks: Optional[List], + config: Any + ) -> CallbackHandler: + """ + Create callback handler. 
+ + Args: + callbacks: User-provided callback list or None + config: Configuration object + + Returns: + CallbackHandler instance + """ + # Prepare initial callback list + callback_list: List = [] + if callbacks: + callback_list.extend(callbacks) + + # Build callbacks from config and extend + callback_config = getattr(config, 'callbacks', None) + if callback_config is not None: + cbs = build_callback(callback_config) + if cbs: + callback_list.extend(cbs) + + # Create handler with complete list + cb_handler = CallbackHandler( + callbacks=callback_list, + model=self.model, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler + ) + + return cb_handler + + def _create_metrics( + self, + compute_metrics: Optional[Callable], + metric_config: Optional[Dict] + ) -> Optional[Callable]: + """ + Create or validate metrics function. + + Args: + compute_metrics: User-provided metrics function or None + metric_config: Metrics configuration from yaml + + Returns: + Metrics function or None + """ + # If user provided metrics function, use it directly + if compute_metrics is not None: + logger.info("Using user-provided compute_metrics function.") + return compute_metrics + + # Build from config + if metric_config is None: + return None + + logger.info("Building metrics from config...") + return build_metric(metric_config) + + def train( + self, + checkpoint_path: Optional[str] = None, + mode: str = "pretrain", + do_eval: bool = False + ): + """ + Execute the training loop. 
+ + Args: + checkpoint_path: Path to checkpoint file to load + mode: Training mode, either "pretrain" or "finetune" + do_eval: Whether to run evaluation + + Returns: + Training output/results + """ + # Validate mode + if mode not in ["pretrain", "finetune"]: + raise ValueError(f"mode must be 'pretrain' or 'finetune', got: {mode}") + + # Check checkpoint_path + if mode == "finetune" and checkpoint_path is None: + load_checkpoint_path = getattr(self.config, 'load_checkpoint', None) + if load_checkpoint_path is None: + raise ValueError( + "In finetune mode, checkpoint_path cannot be None. " + "Please provide checkpoint_path or set config.load_checkpoint" + ) + checkpoint_path = load_checkpoint_path + elif checkpoint_path is None and hasattr(self.config, 'load_checkpoint'): + checkpoint_path = self.config.load_checkpoint + + # Initialize parallel config and wrappers + self._init_parallel_config() + + # Load checkpoint + if checkpoint_path is not None: + self._load_checkpoint(checkpoint_path, mode) + + # Initialize training state + self.state = TrainerState( + max_steps=getattr(self.config, 'max_steps', 1000), + eval_steps=getattr(self.config, 'eval_steps', 100), + save_steps=getattr(self.config, 'save_steps', 100), + global_batch_size=getattr(self.config, 'global_batch_size', 0), + ) + + # Calculate epoch step + if self.train_dataset is not None: + self.state.epoch_step = self._get_dataset_size(self.train_dataset) + + # Call train begin callback + if self.callback_handler is not None: + self.callback_handler.on_train_begin(self.config, self.state) + + # Execute training loop + self._inner_train_loop(do_eval) + + # Call train end callback + if self.callback_handler is not None: + self.callback_handler.on_train_end(self.config, self.state) + + def _init_parallel_config(self): + """Initialize parallel configuration.""" + # Initialize parallel configuration + # 1) HSDP wrapper + # 2) Pipeline parallel config + # 3) Data parallel config + logger.info("Initializing parallel 
config...") + + def _load_checkpoint(self, checkpoint_path: str, mode: str): + """ + Load checkpoint from file. + + Args: + checkpoint_path: Path to checkpoint file + mode: Training mode ("pretrain" or "finetune") + """ + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") + + logger.info(f"Loading checkpoint from: {checkpoint_path}") + + # Prepare global_step possibly adjusted by global_batch_size differences later + global_step = getattr(self.state, 'global_step', None) if hasattr(self, 'state') else None + + # balanced_load flag from config if available + balanced_load = getattr(self.config, 'balanced_load', False) + + # Use updated API signature per spec + load_checkpoint( + checkpoint=checkpoint_path, + network=self.model, + optimizer=self.optimizer if mode == "pretrain" else None, + global_step=global_step, + balanced_load=balanced_load, + ) + + # Implement: 需要将commoninfo的参数设置给lr模块、数据集模块等,请根据BaseTrainer的逻辑完善该部分逻辑的修改 + + def _get_dataset_size(self, dataset) -> int: + """ + Get the size of a dataset. + + Args: + dataset: Dataset instance + + Returns: + Number of batches in the dataset + """ + if hasattr(dataset, '__len__'): + return len(dataset) + if hasattr(dataset, 'get_dataset_size'): + return dataset.get_dataset_size() + # Cannot determine dataset size; raise error per spec + raise ValueError("Unable to determine dataset size from the provided dataset.") + + def _inner_train_loop(self, do_eval: bool = False): + """ + Internal training loop. 
+ + Args: + do_eval: Whether to run evaluation + """ + if self.train_dataset is None: + raise ValueError("train_dataset is None, cannot train.") + + # Create dataset iterator + dataset_iter = self._create_dataset_iterator(self.train_dataset) + + # Training loop + step = self.state.global_step + while step < self.state.max_steps: + # Check epoch begin + if step % self.state.epoch_step == 0 and step > 0: + if self.callback_handler is not None: + self.callback_handler.on_epoch_begin(self.config, self.state) + self.state.update_epoch() + + # Get batch data + try: + inputs = self.get_batch(dataset_iter) + except StopIteration: + # Recreate iterator if dataset exhausted + dataset_iter = self._create_dataset_iterator(self.train_dataset) + inputs = self.get_batch(dataset_iter) + + # Step begin callback + if self.callback_handler is not None: + self.callback_handler.on_step_begin(self.config, self.state) + + # Training step + try: + loss = self.training_step(self.model, inputs) + except Exception as e: + logger.error(f"Error in training step {step}: {e}") + raise + + # Update state + self.state.global_step += 1 + step = self.state.global_step + + # Step end callback (pass loss) + if self.callback_handler is not None: + self.callback_handler.on_step_end( + self.config, + self.state, + loss=loss + ) + + # Check epoch end + if step % self.state.epoch_step == 0: + if self.callback_handler is not None: + self.callback_handler.on_epoch_end(self.config, self.state) + + # Evaluation + if do_eval and self.state.eval_steps > 0 and step % self.state.eval_steps == 0: + self.evaluate() + + def _create_dataset_iterator(self, dataset): + """ + Create an iterator for the dataset using MindSpore Dataset API. 
+ + Args: + dataset: mindspore.dataset.Dataset instance + + Returns: + Dictionary iterator from MindSpore dataset + """ + if hasattr(dataset, 'create_dict_iterator'): + return dataset.create_dict_iterator() + raise TypeError(f"Dataset type {type(dataset)} does not support create_dict_iterator()") + + def get_batch(self, dataset_iter) -> Dict[str, Any]: + """ + Get a batch of data from the dataset. + + Modes: + - Distributed dataset mode + - Remove redundant load mode + - Naive loading mode + """ + use_distribute_dataset = getattr(self.config, 'use_distribute_dataset', False) + use_remove_redundant_dataset = getattr(self.config, 'use_remove_redundant_dataset', False) + + if use_distribute_dataset: + data = self._get_batch_distributed(dataset_iter) + elif use_remove_redundant_dataset: + data = self._get_batch_remove_redundant(dataset_iter) + else: + data = self._get_batch_naive(dataset_iter) + + # Ensure dict output + if data is not None and not isinstance(data, dict): + if isinstance(data, (tuple, list)): + data = {"input_ids": data[0]} if len(data) > 0 else {} + return data if data is not None else {} + + def _get_batch_distributed(self, dataset_iter): + """Fetch next batch in distributed dataset mode (simplified).""" + return next(dataset_iter) + + def _get_batch_remove_redundant(self, dataset_iter): + """Fetch next batch in remove-redundant mode (simplified).""" + return next(dataset_iter) + + def _get_batch_naive(self, dataset_iter): + """Fetch next batch in naive loading mode (simplified).""" + return next(dataset_iter) + + def compute_loss( + self, + model, + inputs: Dict[str, Any] + ): + """ + Compute loss for the model. 
+ + Args: + model: The model + inputs: Input data dictionary + + Returns: + Loss value + """ + # Forward pass + outputs = model(**inputs) + + # Get labels from inputs + labels = inputs.get('labels', None) + + # Compute loss + if self.compute_loss_func is not None: + # Use user-defined loss function + loss = self.compute_loss_func(outputs, labels) + else: + # Extract loss from model output + # We don't use .loss here since the model may return tuples instead of ModelOutput + if isinstance(outputs, dict): + loss = outputs["loss"] + else: + # Assume first element is loss + loss = outputs[0] + + return loss + + def training_step( + self, + model, + inputs: Dict[str, Any] + ): + """ + Perform a single training step. + + Args: + model: The model + inputs: Input data dictionary + + Returns: + Loss value + """ + # Forward and compute loss + loss = self.compute_loss(model, inputs) + + # Backward pass + # In real implementation with MindSpore: + + # Optimizer step + + return loss + + def evaluate(self): + """Placeholder for evaluation; to be implemented.""" diff --git a/tests/st/test_ut/test_callback/__init__.py b/tests/st/test_ut/test_callback/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aac198b55d5c6910a63d287085319c6a6fe0498c --- /dev/null +++ b/tests/st/test_ut/test_callback/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for callback module.""" diff --git a/tests/st/test_ut/test_callback/test_callback.py b/tests/st/test_ut/test_callback/test_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..5ac39a7ddf5f946212b91eb4c1d2dc02bb210e5d --- /dev/null +++ b/tests/st/test_ut/test_callback/test_callback.py @@ -0,0 +1,298 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for callback module using pytest.""" +import os +import sys +from unittest.mock import Mock + +import pytest + +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../..')) + +from mindformers.core.callback_pynative import TrainerCallback, CallbackHandler # pylint: disable=wrong-import-position + + +class TestTrainerCallback: + """Test cases for TrainerCallback base class.""" + + @pytest.fixture + def callback(self): + """Create callback instance.""" + return TrainerCallback() + + @pytest.fixture + def mock_args(self): + """Create mock args.""" + return Mock() + + @pytest.fixture + def mock_state(self): + """Create mock state.""" + return Mock() + + def test_on_begin(self, callback, mock_args, mock_state): + """Test on_begin method.""" + # Should not raise error (base implementation does nothing) + callback.on_begin(mock_args, mock_state) + + def test_on_end(self, callback, mock_args, mock_state): + """Test on_end method.""" + callback.on_end(mock_args, mock_state) + + def test_on_train_begin(self, callback, mock_args, mock_state): + """Test on_train_begin method.""" + callback.on_train_begin(mock_args, mock_state) + + def test_on_train_end(self, callback, mock_args, mock_state): + """Test on_train_end method.""" + callback.on_train_end(mock_args, mock_state) + + def test_on_epoch_begin(self, callback, mock_args, mock_state): + """Test on_epoch_begin method.""" + callback.on_epoch_begin(mock_args, mock_state) + + def test_on_epoch_end(self, callback, mock_args, mock_state): + """Test on_epoch_end method.""" + callback.on_epoch_end(mock_args, mock_state) + + def test_on_step_begin(self, callback, mock_args, mock_state): + """Test on_step_begin method.""" + callback.on_step_begin(mock_args, mock_state) + + def test_on_step_end(self, callback, mock_args, mock_state): + """Test on_step_end method.""" + callback.on_step_end(mock_args, 
mock_state) + + +class TestCallbackHandler: + """Test cases for CallbackHandler class.""" + + @pytest.fixture + def model(self): + """Create mock model.""" + return Mock() + + @pytest.fixture + def train_dataset(self): + """Create mock train dataset.""" + return Mock() + + @pytest.fixture + def eval_dataset(self): + """Create mock eval dataset.""" + return Mock() + + @pytest.fixture + def optimizer(self): + """Create mock optimizer.""" + return Mock() + + @pytest.fixture + def lr_scheduler(self): + """Create mock lr scheduler.""" + return Mock() + + def test_init_empty(self): + """Test initialization with no callbacks.""" + handler = CallbackHandler() + assert len(handler.callbacks) == 0 + assert handler.model is None + assert handler.train_dataset is None + + def test_init_with_components(self, model, train_dataset, eval_dataset, optimizer, lr_scheduler): + """Test initialization with model and datasets.""" + handler = CallbackHandler( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + optimizer=optimizer, + lr_scheduler=lr_scheduler + ) + assert handler.model == model + assert handler.train_dataset == train_dataset + assert handler.eval_dataset == eval_dataset + assert handler.optimizer == optimizer + assert handler.lr_scheduler == lr_scheduler + + def test_add_callback_instance(self): + """Test adding callback instance.""" + handler = CallbackHandler() + cb = TrainerCallback() + handler.add_callback(cb) + assert len(handler.callbacks) == 1 + assert handler.callbacks[0] == cb + + def test_add_callback_class(self): + """Test adding callback class (should instantiate).""" + handler = CallbackHandler() + handler.add_callback(TrainerCallback) + assert len(handler.callbacks) == 1 + assert isinstance(handler.callbacks[0], TrainerCallback) + + def test_add_callback_duplicate_warning(self): + """Test warning when adding duplicate callback type.""" + handler = CallbackHandler() + cb1 = TrainerCallback() + cb2 = TrainerCallback() + + 
handler.add_callback(cb1) + handler.add_callback(cb2) + + # Both should still be added + assert len(handler.callbacks) == 2 + + def test_pop_callback_by_instance(self): + """Test removing callback by instance.""" + handler = CallbackHandler() + cb = TrainerCallback() + handler.add_callback(cb) + + removed = handler.pop_callback(cb) + assert removed == cb + assert len(handler.callbacks) == 0 + + def test_pop_callback_by_class(self): + """Test removing callback by class.""" + handler = CallbackHandler() + cb = TrainerCallback() + handler.add_callback(cb) + + removed = handler.pop_callback(TrainerCallback) + assert isinstance(removed, TrainerCallback) + assert len(handler.callbacks) == 0 + + def test_pop_callback_not_found(self): + """Test popping non-existent callback.""" + handler = CallbackHandler() + removed = handler.pop_callback(TrainerCallback) + assert removed is None + + def test_remove_callback_by_instance(self): + """Test removing callback by instance.""" + handler = CallbackHandler() + cb = TrainerCallback() + handler.add_callback(cb) + + handler.remove_callback(cb) + assert len(handler.callbacks) == 0 + + def test_remove_callback_by_class(self): + """Test removing callback by class.""" + handler = CallbackHandler() + cb = TrainerCallback() + handler.add_callback(cb) + + handler.remove_callback(TrainerCallback) + assert len(handler.callbacks) == 0 + + def test_on_begin_calls_all_callbacks(self): + """Test on_begin calls all registered callbacks.""" + handler = CallbackHandler() + + cb1 = Mock(spec=TrainerCallback) + cb2 = Mock(spec=TrainerCallback) + handler.callbacks = [cb1, cb2] + + args = Mock() + state = Mock() + + handler.on_begin(args, state) + + cb1.on_begin.assert_called_once() + cb2.on_begin.assert_called_once() + + def test_on_train_begin_calls_all_callbacks(self, model, optimizer): + """Test on_train_begin calls all registered callbacks.""" + handler = CallbackHandler( + model=model, + optimizer=optimizer + ) + + cb = Mock(spec=TrainerCallback) 
+ handler.callbacks = [cb] + + args = Mock() + state = Mock() + + handler.on_train_begin(args, state) + + cb.on_train_begin.assert_called_once() + # Check that model and optimizer are passed + call_kwargs = cb.on_train_begin.call_args[1] + assert call_kwargs['model'] == model + assert call_kwargs['optimizer'] == optimizer + + def test_on_step_end_calls_all_callbacks(self): + """Test on_step_end calls all registered callbacks.""" + handler = CallbackHandler() + + cb1 = Mock(spec=TrainerCallback) + cb2 = Mock(spec=TrainerCallback) + handler.callbacks = [cb1, cb2] + + args = Mock() + state = Mock() + + handler.on_step_end(args, state, loss=0.5) + + cb1.on_step_end.assert_called_once() + cb2.on_step_end.assert_called_once() + + # Check loss is passed in kwargs + call_kwargs = cb1.on_step_end.call_args[1] + assert call_kwargs['loss'] == 0.5 + + def test_call_event(self): + """Test call_event method.""" + handler = CallbackHandler() + + cb = Mock(spec=TrainerCallback) + cb.on_train_begin = Mock(return_value="result") + handler.callbacks = [cb] + + args = Mock() + state = Mock() + + result = handler.call_event("on_train_begin", args, state) + + cb.on_train_begin.assert_called_once() + assert result == "result" + + def test_callback_list_property(self): + """Test callback_list property returns string representation.""" + handler = CallbackHandler() + + class CustomCallback1(TrainerCallback): + pass + + class CustomCallback2(TrainerCallback): + pass + + handler.add_callback(CustomCallback1()) + handler.add_callback(CustomCallback2()) + + callback_list = handler.callback_list + assert "CustomCallback1" in callback_list + assert "CustomCallback2" in callback_list + + def test_init_with_callback_list(self): + """Test initialization with callback list.""" + cb1 = TrainerCallback() + cb2 = TrainerCallback() + + handler = CallbackHandler(callbacks=[cb1, cb2]) + + assert len(handler.callbacks) == 2 diff --git a/tests/st/test_ut/test_callback/test_checkpoint_callback.py 
b/tests/st/test_ut/test_callback/test_checkpoint_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..01c076f175119408006126515941160002f0da8b --- /dev/null +++ b/tests/st/test_ut/test_callback/test_checkpoint_callback.py @@ -0,0 +1,344 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Unit tests for CheckpointCallback using pytest.""" +import os +import shutil +import sys +import tempfile +from unittest.mock import Mock, patch + +import pytest + +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../..')) + +from mindformers.core.callback_pynative import CheckpointCallback # pylint: disable=wrong-import-position + + +class TestCheckpointCallback: + """Test cases for CheckpointCallback.""" + + @pytest.fixture + def temp_dir(self): + """Create temporary directory.""" + temp_dir = tempfile.mkdtemp() + yield temp_dir + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + @pytest.fixture + def save_dir(self, temp_dir): + """Create save directory path.""" + return os.path.join(temp_dir, 'checkpoints') + + @pytest.fixture + def callback(self, save_dir): + """Create callback instance.""" + return CheckpointCallback( + save_dir=save_dir, + save_interval=100, + save_optimizer=True, + keep_checkpoint_max=3 + ) + + @pytest.fixture + def mock_args(self): + """Create mock 
args.""" + return Mock() + + @pytest.fixture + def mock_state(self): + """Create mock state.""" + state = Mock() + state.global_step = 0 + state.epoch = 0.0 + state.epoch_step = 100 + state.global_batch_size = 32 + return state + + @pytest.fixture + def mock_model(self): + """Create mock model.""" + return Mock() + + @pytest.fixture + def mock_optimizer(self): + """Create mock optimizer.""" + return Mock() + + def test_init_default_parameters(self): + """Test initialization with default parameters.""" + cb = CheckpointCallback(save_dir='/path/to/ckpts') + + assert cb.save_dir == '/path/to/ckpts' + assert cb.save_interval == 1000 + assert cb.save_optimizer is True + assert cb.keep_checkpoint_max == 5 + assert cb.save_on_train_end is True + assert cb.user_prefix == "checkpoint" + assert cb.async_save is False + assert cb.remove_redundancy is False + assert cb.async_save_manager is None + + def test_init_custom_parameters(self): + """Test initialization with custom parameters.""" + cb = CheckpointCallback( + save_dir='/path/to/ckpts', + save_interval=500, + save_optimizer=False, + keep_checkpoint_max=10, + save_on_train_end=False, + user_prefix="my_ckpt", + async_save=True, + remove_redundancy=True + ) + + assert cb.save_dir == '/path/to/ckpts' + assert cb.save_interval == 500 + assert cb.save_optimizer is False + assert cb.keep_checkpoint_max == 10 + assert cb.save_on_train_end is False + assert cb.user_prefix == "my_ckpt" + assert cb.async_save is True + assert cb.remove_redundancy is True + + @patch('mindformers.core.callback_pynative.checkpoint_callback.AsyncSaveManager') + def test_init_with_async_save(self, mock_async_manager_class, save_dir): + """Test initialization creates AsyncSaveManager when async_save is True.""" + mock_async_manager = Mock() + mock_async_manager_class.return_value = mock_async_manager + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + cb = CheckpointCallback( + save_dir=save_dir, + async_save=True + ) + + 
# Verify AsyncSaveManager was created + mock_async_manager_class.assert_called_once_with(async_save=True) + assert cb.async_save_manager == mock_async_manager + + def test_on_train_begin_creates_directory(self, callback, save_dir, mock_args, mock_state): + """Test on_train_begin creates save directory.""" + assert not os.path.exists(save_dir) + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + callback.on_train_begin(mock_args, mock_state) + + assert os.path.exists(save_dir) + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_on_step_end_saves_at_interval( + self, mock_save_checkpoint, callback, save_dir, mock_args, mock_state, mock_model, mock_optimizer + ): + """Test on_step_end saves checkpoint at interval.""" + os.makedirs(save_dir, exist_ok=True) + + # Step 50 - should not save + mock_state.global_step = 50 + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + callback.on_step_end(mock_args, mock_state, model=mock_model, optimizer=mock_optimizer) + + mock_save_checkpoint.assert_not_called() + + # Step 100 - should save + mock_state.global_step = 100 + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + callback.on_step_end(mock_args, mock_state, model=mock_model, optimizer=mock_optimizer) + + # Verify save_checkpoint was called + assert mock_save_checkpoint.called + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_save_checkpoint_with_common_info( + self, mock_save_checkpoint, callback, save_dir, mock_args, mock_state, mock_model + ): + """Test checkpoint saves with CommonInfo (always enabled).""" + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 100 + mock_state.epoch = 1.5 + mock_state.epoch_step = 100 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + callback.on_step_end(mock_args, mock_state, model=mock_model) + + # Verify common_info 
was passed + call_args = mock_save_checkpoint.call_args + assert 'common_info' in call_args[1] + common_info = call_args[1]['common_info'] + assert common_info is not None + assert common_info.global_step == 100 + assert common_info.epoch_num == 1 + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_save_checkpoint_with_async_save( + self, mock_save_checkpoint, save_dir, mock_args, mock_state, mock_model + ): + """Test checkpoint saves with async save enabled.""" + mock_async_manager = Mock() + + with patch('mindformers.core.callback_pynative.checkpoint_callback.AsyncSaveManager') as mock_async_class: + mock_async_class.return_value = mock_async_manager + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + cb = CheckpointCallback( + save_dir=save_dir, + save_interval=100, + async_save=True + ) + + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 100 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + cb.on_step_end(mock_args, mock_state, model=mock_model) + + # Verify prepare_before_save was called + mock_async_manager.prepare_before_save.assert_called_once() + + # Verify async_save_manager was passed + call_args = mock_save_checkpoint.call_args + assert call_args[1]['async_save_manager'] == mock_async_manager + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_save_checkpoint_with_remove_redundancy( + self, mock_save_checkpoint, save_dir, mock_args, mock_state, mock_model + ): + """Test checkpoint saves with remove_redundancy enabled.""" + cb = CheckpointCallback( + save_dir=save_dir, + save_interval=100, + remove_redundancy=True + ) + + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 100 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + cb.on_step_end(mock_args, mock_state, model=mock_model) + + # Verify remove_redundancy was passed + call_args = 
mock_save_checkpoint.call_args + assert call_args[1]['remove_redundancy'] is True + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_save_checkpoint_all_parameters( + self, mock_save_checkpoint, save_dir, mock_args, mock_state, mock_model, mock_optimizer + ): + """Test all parameters are passed correctly to save_checkpoint.""" + cb = CheckpointCallback( + save_dir=save_dir, + save_interval=100, + save_optimizer=True, + keep_checkpoint_max=5, + user_prefix="test", + async_save=False, + remove_redundancy=True + ) + + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 200 + mock_state.epoch = 2.0 + mock_state.epoch_step = 100 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + cb.on_step_end(mock_args, mock_state, model=mock_model, optimizer=mock_optimizer) + + # Verify all expected parameters + mock_save_checkpoint.assert_called_once() + call_args = mock_save_checkpoint.call_args + + assert call_args[1]['iteration'] == 200 + assert call_args[1]['network'] == mock_model + assert call_args[1]['optimizer'] == mock_optimizer + assert call_args[1]['async_save_manager'] is None # Not enabled + assert call_args[1]['common_info'] is not None + assert call_args[1]['keep_max_num'] == 5 + assert call_args[1]['user_prefix'] == "test" + assert call_args[1]['save_checkpoint_path'] == save_dir + assert call_args[1]['remove_redundancy'] is True + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_create_common_info_from_state( + self, mock_save_checkpoint, callback, save_dir, mock_args, mock_state, mock_model + ): + """Test CommonInfo creation from TrainerState.""" + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 200 # Changed to 200 (multiple of save_interval=100) + mock_state.epoch = 2.0 + mock_state.epoch_step = 100 + mock_state.global_batch_size = 64 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + 
callback.on_step_end(mock_args, mock_state, model=mock_model) + + # Get common_info from call + call_args = mock_save_checkpoint.call_args + common_info = call_args[1]['common_info'] + + # Verify CommonInfo fields + assert common_info.epoch_num == 2 + assert common_info.global_step == 200 + assert common_info.step_num == 0 # 200 % 100 = 0 + assert common_info.global_batch_size == 64 + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_on_train_end_final_checkpoint( + self, mock_save_checkpoint, callback, save_dir, mock_args, mock_state, mock_model + ): + """Test on_train_end saves final checkpoint with same user_prefix (no _final suffix).""" + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 1000 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger'): + callback.on_train_end(mock_args, mock_state, model=mock_model) + + # Verify save_checkpoint was called with same prefix (no _final) + call_args = mock_save_checkpoint.call_args + assert call_args[1]['user_prefix'] == "checkpoint" # Default user_prefix + assert call_args[1]['keep_max_num'] == 3 # Same as callback.keep_checkpoint_max + + @patch('mindformers.core.callback_pynative.checkpoint_callback.save_checkpoint') + def test_error_handling( + self, mock_save_checkpoint, callback, save_dir, mock_args, mock_state, mock_model + ): + """Test error handling during checkpoint save.""" + os.makedirs(save_dir, exist_ok=True) + mock_state.global_step = 100 + + # Mock save_checkpoint to raise error + mock_save_checkpoint.side_effect = Exception("Save error") + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger') as mock_logger: + callback.on_step_end(mock_args, mock_state, model=mock_model) + + # Should log error message + mock_logger.error.assert_called_once() + assert "Error saving checkpoint" in str(mock_logger.error.call_args) + + def test_no_model_warning(self, callback, mock_args, mock_state): + """Test warning 
when no model is provided.""" + mock_state.global_step = 100 + + with patch('mindformers.core.callback_pynative.checkpoint_callback.logger') as mock_logger: + callback.on_step_end(mock_args, mock_state) + + # Should log warning + mock_logger.warning.assert_called_once() + assert "No model provided" in str(mock_logger.warning.call_args) + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/st/test_ut/test_callback/test_loss_callback.py b/tests/st/test_ut/test_callback/test_loss_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..f6a4b4b18fa2fb110a92f4f5c7139d3da5ebd2f9 --- /dev/null +++ b/tests/st/test_ut/test_callback/test_loss_callback.py @@ -0,0 +1,242 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Unit tests for LossCallback using pytest.""" +import os +import sys +import time +from unittest.mock import Mock, patch + +import pytest + +# Add project root to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../..')) + +from mindformers.core.callback_pynative import LossCallback # pylint: disable=wrong-import-position + + +class TestLossCallback: + """Test cases for LossCallback.""" + + @pytest.fixture + def callback(self): + """Create callback fixture.""" + return LossCallback(log_interval=1) + + @pytest.fixture + def mock_args(self): + """Create mock args.""" + return Mock() + + @pytest.fixture + def mock_state(self): + """Create mock state.""" + state = Mock() + state.global_step = 0 + state.max_steps = 1000 + state.epoch = 0.0 + return state + + def test_init(self): + """Test initialization.""" + cb = LossCallback(log_interval=10) + assert cb.log_interval == 10 + + def test_on_train_begin(self, callback, mock_args, mock_state): + """Test on_train_begin resets state.""" + callback.on_train_begin(mock_args, mock_state) + + assert callback.step_time is not None + assert callback.epoch_time is not None + + def test_on_epoch_begin(self, callback, mock_args, mock_state): + """Test on_epoch_begin updates epoch time.""" + old_time = callback.epoch_time + time.sleep(0.01) + callback.on_epoch_begin(mock_args, mock_state) + + assert callback.epoch_time != old_time + + def test_on_step_begin(self, callback, mock_args, mock_state): + """Test on_step_begin updates step time.""" + old_time = callback.step_time + time.sleep(0.01) + callback.on_step_begin(mock_args, mock_state) + + assert callback.step_time != old_time + + def test_on_step_end_with_no_loss(self, callback, mock_args, mock_state): + """Test on_step_end when no loss provided.""" + # Should handle gracefully without error + callback.on_step_end(mock_args, mock_state) + + def 
test_on_step_end_with_float_loss(self, callback, mock_args, mock_state): + """Test on_step_end with float loss.""" + mock_state.global_step = 1 + mock_state.max_steps = 1000 + loss = 0.5 + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + callback.on_step_end(mock_args, mock_state, loss=loss) + # Should log the loss in MFLossMonitor format + mock_logger.info.assert_called() + call_args = str(mock_logger.info.call_args) + # Check for MFLossMonitor-style format: "step:[X/Y]" + assert 'step:[' in call_args + + def test_on_step_end_with_tensor_loss(self, callback, mock_args, mock_state): + """Test on_step_end with tensor-like loss.""" + mock_state.global_step = 1 + mock_state.max_steps = 1000 + + # Mock tensor with asnumpy method + loss = Mock() + loss.asnumpy = Mock(return_value=0.5) + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + callback.on_step_end(mock_args, mock_state, loss=loss) + # Should log the loss + mock_logger.info.assert_called() + call_args = str(mock_logger.info.call_args) + assert 'loss:' in call_args + + def test_on_step_end_logs_at_interval(self, mock_args, mock_state): + """Test that logging happens at specified interval.""" + cb = LossCallback(log_interval=2) + cb.on_train_begin(mock_args, mock_state) + + # Step 1 - should not log + mock_state.global_step = 1 + mock_state.max_steps = 1000 + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + cb.on_step_end(mock_args, mock_state, loss=0.5) + # Logger should not be called (step 1 % 2 != 0) + mock_logger.info.assert_not_called() + + # Step 2 - should log + mock_state.global_step = 2 + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + cb.on_step_end(mock_args, mock_state, loss=0.4) + # Logger should be called + mock_logger.info.assert_called() + call_args = str(mock_logger.info.call_args) + # Verify MFLossMonitor-style format + assert 
'per_step_time:' in call_args + + def test_on_step_end_with_lr_scheduler(self, callback, mock_args, mock_state): + """Test on_step_end includes learning rate when available.""" + mock_state.global_step = 1 + mock_state.max_steps = 1000 + + lr_scheduler = Mock() + lr_scheduler.get_last_lr = Mock(return_value=0.001) + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + callback.on_step_end( + mock_args, + mock_state, + loss=0.5, + lr_scheduler=lr_scheduler + ) + # Should log with learning rate + mock_logger.info.assert_called() + # Note: LR extraction from lr_scheduler is TODO, so won't appear yet + + def test_on_step_end_multiple_steps(self, mock_args, mock_state): + """Test that on_step_end works correctly over multiple steps.""" + cb = LossCallback(log_interval=1) + cb.on_train_begin(mock_args, mock_state) + + # Add 3 losses + for i in range(3): + mock_state.global_step = i + 1 + mock_state.max_steps = 1000 + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + cb.on_step_end(mock_args, mock_state, loss=float(i + 1)) + # Should log each step + mock_logger.info.assert_called() + # Verify format includes step info + call_args = str(mock_logger.info.call_args) + assert f'{i+1:5d}/1000' in call_args or 'step:[' in call_args + + def test_on_epoch_end(self, callback, mock_args, mock_state): + """Test on_epoch_end prints epoch info.""" + mock_state.epoch = 1.0 + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + callback.on_epoch_end(mock_args, mock_state) + + # Should log epoch info + mock_logger.info.assert_called() + call_args = str(mock_logger.info.call_args) + assert 'Epoch' in call_args + + def test_on_epoch_end_always_logs(self, callback, mock_args, mock_state): + """Test on_epoch_end always logs epoch info.""" + mock_state.epoch = 2.0 + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + 
callback.on_epoch_end(mock_args, mock_state) + + # Should always log epoch completion + mock_logger.info.assert_called() + + def test_print_log_format(self, callback): + """Test _print_log formats correctly in MFLossMonitor style.""" + log_info = { + 'cur_step': 100, + 'max_steps': 1000, + 'loss': 0.123456, + 'learning_rate': 0.001, + 'step_time': 123 + } + + mock_state = Mock() + mock_state.global_step = 100 + mock_state.max_steps = 1000 + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + callback._print_log(log_info, mock_state) # pylint: disable=protected-access + + # Check logger was called + mock_logger.info.assert_called_once() + logged_str = str(mock_logger.info.call_args[0][0]) + + # Verify MFLossMonitor-style format + assert 'step:[' in logged_str + assert 'loss:' in logged_str + assert 'per_step_time:' in logged_str + assert 'lr:' in logged_str + + def test_print_log_with_list_lr(self, callback): + """Test _print_log handles list learning rate.""" + log_info = { + 'cur_step': 100, + 'max_steps': 1000, + 'loss': 0.5, + 'step_time': 100, + 'learning_rate': [0.001, 0.002] # List of LRs + } + + mock_state = Mock() + mock_state.global_step = 100 + mock_state.max_steps = 1000 + + with patch('mindformers.core.callback_pynative.loss_callback.logger') as mock_logger: + callback._print_log(log_info, mock_state) # pylint: disable=protected-access + + mock_logger.info.assert_called_once() + logged_str = str(mock_logger.info.call_args[0][0]) + # Should use first LR from list + assert 'lr:' in logged_str diff --git a/tests/st/test_ut/test_trainer/__init__.py b/tests/st/test_ut/test_trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c509819750e047990eceeead148da4b7032c60b2 --- /dev/null +++ b/tests/st/test_ut/test_trainer/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Unit tests for trainer module.""" diff --git a/tests/st/test_ut/test_trainer/test_train_state.py b/tests/st/test_ut/test_trainer/test_train_state.py new file mode 100644 index 0000000000000000000000000000000000000000..cc69e1ce0e979b5afa4e5ddbefb1247b883cbe8f --- /dev/null +++ b/tests/st/test_ut/test_trainer/test_train_state.py @@ -0,0 +1,214 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Unit tests for TrainerState."""
import os
import sys

import pytest

# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../..'))

from mindformers.trainer_pynative.train_state import TrainerState  # pylint: disable=wrong-import-position


class TestTrainerState:
    """Test cases for TrainerState.

    TrainerState is a plain value object, so every test is a pure
    construct-then-assert check with no mocking required.
    """

    def test_init_default_parameters(self):
        """Test initialization with default parameters."""
        state = TrainerState()

        assert state.epoch == 0.0
        assert state.global_step == 0
        assert state.max_steps == 0
        assert state.eval_steps == 0
        assert state.save_steps == 0
        assert state.epoch_step == 0
        assert state.global_batch_size == 0
        assert state.best_metric is None
        assert state.best_model_checkpoint is None
        assert not state.is_train_begin
        assert not state.is_train_end

    def test_init_custom_parameters(self):
        """Test initialization with custom parameters."""
        state = TrainerState(
            epoch=2.5,
            global_step=1000,
            max_steps=5000,
            eval_steps=100,
            save_steps=500,
            epoch_step=200,
            global_batch_size=64,
            best_metric=0.95
        )

        assert state.epoch == 2.5
        assert state.global_step == 1000
        assert state.max_steps == 5000
        assert state.eval_steps == 100
        assert state.save_steps == 500
        assert state.epoch_step == 200
        assert state.global_batch_size == 64
        assert state.best_metric == 0.95

    def test_update_epoch(self):
        """Test update_epoch method."""
        state = TrainerState(
            global_step=250,
            epoch_step=100
        )

        state.update_epoch()

        # epoch is derived as global_step / epoch_step (250 / 100 = 2.5)
        assert state.epoch == 2.5

    def test_update_epoch_zero_epoch_step(self):
        """Test update_epoch with zero epoch_step."""
        state = TrainerState(
            global_step=100,
            epoch_step=0
        )

        # Should not crash (division-by-zero guard expected in update_epoch)
        state.update_epoch()
        assert state.epoch == 0.0

    def test_save_to_dict(self):
        """Test save_to_dict method."""
        state = TrainerState(
            epoch=2.5,
            global_step=1000,
            max_steps=5000,
            eval_steps=100,
            save_steps=500,
            epoch_step=200,
            global_batch_size=64,
            best_metric=0.95,
            best_model_checkpoint="/path/to/ckpt"
        )

        state_dict = state.save_to_dict()

        assert state_dict["epoch"] == 2.5
        assert state_dict["global_step"] == 1000
        assert state_dict["max_steps"] == 5000
        assert state_dict["eval_steps"] == 100
        assert state_dict["save_steps"] == 500
        assert state_dict["epoch_step"] == 200
        assert state_dict["global_batch_size"] == 64
        assert state_dict["best_metric"] == 0.95
        assert state_dict["best_model_checkpoint"] == "/path/to/ckpt"

    def test_load_from_dict(self):
        """Test load_from_dict method."""
        state_dict = {
            "epoch": 3.0,
            "global_step": 1500,
            "max_steps": 6000,
            "eval_steps": 150,
            "save_steps": 600,
            "epoch_step": 300,
            "global_batch_size": 128,
            "best_metric": 0.98,
            "best_model_checkpoint": "/path/to/best"
        }

        state = TrainerState.load_from_dict(state_dict)

        assert state.epoch == 3.0
        assert state.global_step == 1500
        assert state.max_steps == 6000
        assert state.eval_steps == 150
        assert state.save_steps == 600
        assert state.epoch_step == 300
        assert state.global_batch_size == 128
        assert state.best_metric == 0.98
        assert state.best_model_checkpoint == "/path/to/best"

    def test_save_and_load_roundtrip(self):
        """Test saving and loading state (save_to_dict -> load_from_dict)."""
        original_state = TrainerState(
            epoch=4.2,
            global_step=2000,
            max_steps=8000,
            eval_steps=200,
            save_steps=800,
            epoch_step=400,
            global_batch_size=256,
            best_metric=0.99
        )

        # Save to dict
        state_dict = original_state.save_to_dict()

        # Load from dict
        loaded_state = TrainerState.load_from_dict(state_dict)

        # Verify all fields match
        assert loaded_state.epoch == original_state.epoch
        assert loaded_state.global_step == original_state.global_step
        assert loaded_state.max_steps == original_state.max_steps
        assert loaded_state.eval_steps == original_state.eval_steps
        assert loaded_state.save_steps == original_state.save_steps
        assert loaded_state.epoch_step == original_state.epoch_step
        assert loaded_state.global_batch_size == original_state.global_batch_size
        assert loaded_state.best_metric == original_state.best_metric

    def test_repr(self):
        """Test string representation."""
        state = TrainerState(
            epoch=2.5,
            global_step=1000,
            max_steps=5000
        )

        repr_str = repr(state)

        assert "TrainerState" in repr_str
        assert "epoch=2.5" in repr_str
        assert "global_step=1000" in repr_str
        assert "max_steps=5000" in repr_str

    def test_global_batch_size_in_save_to_dict(self):
        """Test that global_batch_size is included in save_to_dict."""
        state = TrainerState(global_batch_size=32)
        state_dict = state.save_to_dict()

        assert "global_batch_size" in state_dict
        assert state_dict["global_batch_size"] == 32

    def test_global_batch_size_in_load_from_dict(self):
        """Test that global_batch_size is loaded from dict."""
        state_dict = {
            "epoch": 0.0,
            "global_step": 0,
            "max_steps": 1000,
            "eval_steps": 100,
            "save_steps": 100,
            "epoch_step": 100,
            "global_batch_size": 64,
            "best_metric": None,
            "best_model_checkpoint": None
        }

        state = TrainerState.load_from_dict(state_dict)

        assert state.global_batch_size == 64


if __name__ == '__main__':
    pytest.main([__file__, '-v'])
diff --git a/tests/st/test_ut/test_trainer/test_trainer.py b/tests/st/test_ut/test_trainer/test_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d77002f64f61e7f52f84975ce9f1d5b95d1f3704
--- /dev/null
+++ b/tests/st/test_ut/test_trainer/test_trainer.py
@@ -0,0 +1,505 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Unit tests for Trainer public interfaces: train() and get_batch()."""
# pylint: disable=protected-access
import os
import sys
from unittest.mock import Mock, patch

import pytest


# Add project root to path
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..'))
sys.path.insert(0, project_root)

# Mock mindspore modules FIRST before any mindformers imports.
# Inserting Mock objects into sys.modules intercepts every later
# `import mindspore...` so the Trainer can be imported on a machine
# without MindSpore installed. Order matters: these assignments must
# run before the first `import mindformers...` statement below.
mindspore_mock = Mock()
mindspore_mock.nn = Mock()
mindspore_mock.dataset = Mock()
mindspore_mock.common = Mock()
mindspore_mock.common.tensor = Mock()
mindspore_mock.common.initializer = Mock()
mindspore_mock.common.parameter = Mock()
mindspore_mock.common.dtype = Mock()
mindspore_mock.ops = Mock()
mindspore_mock.context = Mock()

sys.modules['mindspore'] = mindspore_mock
sys.modules['mindspore.nn'] = mindspore_mock.nn
sys.modules['mindspore.dataset'] = mindspore_mock.dataset
sys.modules['mindspore.common'] = mindspore_mock.common
sys.modules['mindspore.common.tensor'] = mindspore_mock.common.tensor
sys.modules['mindspore.common.initializer'] = mindspore_mock.common.initializer
sys.modules['mindspore.common.parameter'] = mindspore_mock.common.parameter
sys.modules['mindspore.common.dtype'] = mindspore_mock.common.dtype
sys.modules['mindspore.ops'] = mindspore_mock.ops
sys.modules['mindspore.context'] = mindspore_mock.context
sys.modules['mindspore.parallel'] = Mock()
sys.modules['mindspore.train'] = Mock()

# Mock mindspore.communication — single-device defaults (rank 0 of 1)
communication_mock = Mock()
communication_mock.management = Mock()
communication_mock.management.get_rank = Mock(return_value=0)
communication_mock.management.get_group_size = Mock(return_value=1)
sys.modules['mindspore.communication'] = communication_mock
sys.modules['mindspore.communication.management'] = communication_mock.management

sys.modules['mindspore._checkparam'] = Mock()
sys.modules['mindspore.amp'] = Mock()
sys.modules['mindspore._c_expression'] = Mock()

# Mock mindformers modules.
# NOTE(review): `__all__ = []` keeps any `from <pkg> import *` executed
# during package initialization from pulling arbitrary Mock attributes
# into the importing namespace — presumably why it is set on every
# mocked package; confirm against mindformers' __init__ files.
modules_mock = Mock()
modules_mock.transformer = Mock()
modules_mock.__all__ = []
sys.modules['mindformers.modules'] = modules_mock
sys.modules['mindformers.modules.transformer'] = modules_mock.transformer

checkpoint_mock = Mock()
checkpoint_mock.__all__ = []
checkpoint_mock.checkpoint = Mock()
sys.modules['mindformers.checkpoint'] = checkpoint_mock
sys.modules['mindformers.checkpoint.checkpoint'] = checkpoint_mock.checkpoint

models_mock = Mock()
models_mock.llama = Mock()
models_mock.__all__ = []
sys.modules['mindformers.models'] = models_mock
sys.modules['mindformers.models.llama'] = models_mock.llama

dataset_mock = Mock()
dataset_mock.__all__ = []
sys.modules['mindformers.dataset'] = dataset_mock

sys.modules['mindformers.run_check'] = Mock()

core_mock = Mock()
core_mock.context = Mock()
core_mock.__all__ = []
sys.modules['mindformers.core'] = core_mock
sys.modules['mindformers.core.context'] = core_mock.context
sys.modules['mindformers.core.config_args'] = Mock()
sys.modules['mindformers.core.lr'] = Mock()
sys.modules['mindformers.core.optim'] = Mock()
sys.modules['mindformers.core.callback'] = Mock()
sys.modules['mindformers.core.callback_pynative'] = Mock()
sys.modules['mindformers.core.metric'] = Mock()

pet_mock = Mock()
pet_mock.__all__ = []
sys.modules['mindformers.pet'] = pet_mock

wrapper_mock = Mock()
wrapper_mock.__all__ = []
sys.modules['mindformers.wrapper'] = wrapper_mock

generation_mock = Mock()
+generation_mock.__all__ = [] +sys.modules['mindformers.generation'] = generation_mock + +pipeline_mock = Mock() +pipeline_mock.__all__ = [] +sys.modules['mindformers.pipeline'] = pipeline_mock + +trainer_mock = Mock() +trainer_mock.__all__ = [] +trainer_mock.training_args = Mock() +trainer_mock.optimizer_grouped_parameters = Mock() +sys.modules['mindformers.trainer'] = trainer_mock +sys.modules['mindformers.trainer.training_args'] = trainer_mock.training_args +sys.modules['mindformers.trainer.optimizer_grouped_parameters'] = trainer_mock.optimizer_grouped_parameters +sys.modules['mindformers.trainer.general_task_trainer'] = Mock() + +sys.modules['mindformers.model_runner'] = Mock() + +# Mock trainer.train_state +train_state_mock = Mock() +mock_trainer_state_class = Mock() +sys.modules['trainer'] = Mock() +sys.modules['trainer.train_state'] = train_state_mock +train_state_mock.TrainerState = mock_trainer_state_class + +# Import real MindFormerConfig +from mindformers.tools.register import MindFormerConfig # pylint: disable=wrong-import-position + +# Now import Trainer +from mindformers.trainer_pynative.trainer import Trainer # pylint: disable=wrong-import-position + + +class TestTrainerTrain: + """Test cases for Trainer.train() interface.""" + + @pytest.fixture + def mock_config(self): + """Create a mock MindFormerConfig.""" + config = MindFormerConfig() + config.max_steps = 100 + config.eval_steps = 20 + config.save_steps = 50 + config.global_batch_size = 32 + return config + + @pytest.fixture + def mock_trainer_state(self): + """Create a mock TrainerState.""" + state = Mock() + state.global_step = 0 + state.epoch_step = 10 + state.max_steps = 100 + state.eval_steps = 20 + state.save_steps = 50 + state.global_batch_size = 32 + state.update_epoch = Mock() + return state + + @pytest.fixture + def mock_model(self): + """Create a mock model.""" + model = Mock() + model.__call__ = Mock(return_value={'loss': 0.5}) + return model + + @pytest.fixture + def 
mock_dataset(self): + """Create a mock dataset.""" + dataset = Mock() + dataset.__len__ = Mock(return_value=10) + dataset.get_dataset_size = Mock(return_value=10) + + # Create mock iterator + mock_iter = Mock() + mock_iter.__next__ = Mock(side_effect=[ + {'input_ids': [1, 2, 3], 'labels': [2, 3, 4]}, + {'input_ids': [4, 5, 6], 'labels': [5, 6, 7]}, + ] * 100) # Repeat to avoid StopIteration during tests + + dataset.create_dict_iterator = Mock(return_value=mock_iter) + return dataset + + @pytest.fixture + def mock_optimizer(self): + """Create a mock optimizer.""" + optimizer = Mock() + optimizer.step = Mock() + return optimizer + + @pytest.fixture + def mock_callback_handler(self): + """Create a mock CallbackHandler.""" + handler = Mock() + handler.on_train_begin = Mock() + handler.on_train_end = Mock() + handler.on_epoch_begin = Mock() + handler.on_epoch_end = Mock() + handler.on_step_begin = Mock() + handler.on_step_end = Mock() + return handler + + def test_train_pretrain_mode_success( + self, mock_config, mock_model, mock_dataset, + mock_optimizer, mock_callback_handler, mock_trainer_state + ): + """Test train() in pretrain mode executes successfully.""" + # Create trainer instance with mocked components + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + trainer.model = mock_model + trainer.train_dataset = mock_dataset + trainer.eval_dataset = None + trainer.optimizer = mock_optimizer + trainer.lr_scheduler = Mock() + trainer.callback_handler = mock_callback_handler + trainer.compute_metrics = None + trainer.compute_loss_func = None + trainer.processing_class = None + + # Mock internal methods + trainer._init_parallel_config = Mock() + trainer._load_checkpoint = Mock() + trainer._get_dataset_size = Mock(return_value=10) + trainer._inner_train_loop = Mock() + + # Mock TrainerState + with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state): + # Execute train + trainer.train(checkpoint_path=None, mode="pretrain", 
do_eval=False) + + # Verify method calls + trainer._init_parallel_config.assert_called_once() + trainer._load_checkpoint.assert_not_called() # No checkpoint in pretrain mode + mock_callback_handler.on_train_begin.assert_called_once() + mock_callback_handler.on_train_end.assert_called_once() + trainer._inner_train_loop.assert_called_once_with(False) + + def test_train_finetune_mode_requires_checkpoint( + self, mock_config, mock_model, mock_dataset, + mock_optimizer, mock_callback_handler + ): + """Test train() in finetune mode raises error without checkpoint.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + trainer.model = mock_model + trainer.train_dataset = mock_dataset + trainer.eval_dataset = None + trainer.optimizer = mock_optimizer + trainer.lr_scheduler = Mock() + trainer.callback_handler = mock_callback_handler + trainer.compute_metrics = None + trainer.compute_loss_func = None + trainer.processing_class = None + + trainer._init_parallel_config = Mock() + + # Should raise ValueError when checkpoint_path is None in finetune mode + with pytest.raises(ValueError, match="checkpoint_path cannot be None"): + trainer.train(checkpoint_path=None, mode="finetune", do_eval=False) + + def test_train_finetune_mode_with_checkpoint( + self, mock_config, mock_model, mock_dataset, + mock_optimizer, mock_callback_handler, mock_trainer_state + ): + """Test train() in finetune mode with checkpoint loads correctly.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + trainer.model = mock_model + trainer.train_dataset = mock_dataset + trainer.eval_dataset = None + trainer.optimizer = mock_optimizer + trainer.lr_scheduler = Mock() + trainer.callback_handler = mock_callback_handler + trainer.compute_metrics = None + trainer.compute_loss_func = None + trainer.processing_class = None + + trainer._init_parallel_config = Mock() + trainer._load_checkpoint = Mock() + trainer._get_dataset_size = Mock(return_value=10) + trainer._inner_train_loop 
= Mock() + + checkpoint_path = "/mock/checkpoint.ckpt" + + with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state): + with patch('os.path.exists', return_value=True): + trainer.train(checkpoint_path=checkpoint_path, mode="finetune", do_eval=False) + + # Verify checkpoint loading + trainer._load_checkpoint.assert_called_once_with(checkpoint_path, "finetune") + + def test_train_invalid_mode_raises_error( + self, mock_config, mock_model, mock_dataset, + mock_optimizer, mock_callback_handler + ): + """Test train() raises error with invalid mode.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + trainer.model = mock_model + trainer.train_dataset = mock_dataset + trainer.optimizer = mock_optimizer + trainer.callback_handler = mock_callback_handler + + trainer._init_parallel_config = Mock() + + with pytest.raises(ValueError, match="mode must be 'pretrain' or 'finetune'"): + trainer.train(checkpoint_path=None, mode="invalid_mode", do_eval=False) + + def test_train_calls_callbacks_correctly( + self, mock_config, mock_model, mock_dataset, + mock_optimizer, mock_callback_handler, mock_trainer_state + ): + """Test train() calls all callback hooks in correct order.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + trainer.model = mock_model + trainer.train_dataset = mock_dataset + trainer.optimizer = mock_optimizer + trainer.lr_scheduler = Mock() + trainer.callback_handler = mock_callback_handler + trainer.compute_loss_func = None + + trainer._init_parallel_config = Mock() + trainer._load_checkpoint = Mock() + trainer._get_dataset_size = Mock(return_value=10) + trainer._inner_train_loop = Mock() + + with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state): + trainer.train(checkpoint_path=None, mode="pretrain", do_eval=False) + + # Verify callback call order + assert mock_callback_handler.on_train_begin.called + assert mock_callback_handler.on_train_end.called + # on_train_begin should 
be called before on_train_end + call_order = [ + call for call in mock_callback_handler.method_calls + if call[0] in ['on_train_begin', 'on_train_end'] + ] + assert call_order[0][0] == 'on_train_begin' + assert call_order[-1][0] == 'on_train_end' + + def test_train_with_do_eval_true( + self, mock_config, mock_model, mock_dataset, + mock_optimizer, mock_callback_handler, mock_trainer_state + ): + """Test train() with do_eval=True passes flag to inner loop.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + trainer.model = mock_model + trainer.train_dataset = mock_dataset + trainer.optimizer = mock_optimizer + trainer.lr_scheduler = Mock() + trainer.callback_handler = mock_callback_handler + + trainer._init_parallel_config = Mock() + trainer._get_dataset_size = Mock(return_value=10) + trainer._inner_train_loop = Mock() + + with patch('trainer.train_state.TrainerState', return_value=mock_trainer_state): + trainer.train(checkpoint_path=None, mode="pretrain", do_eval=True) + + # Verify do_eval flag is passed + trainer._inner_train_loop.assert_called_once_with(True) + + +class TestTrainerGetBatch: + """Test cases for Trainer.get_batch() interface.""" + + @pytest.fixture + def mock_config(self): + """Create a mock config.""" + config = Mock() + config.use_distribute_dataset = False + config.use_remove_redundant_dataset = False + return config + + @pytest.fixture + def mock_dataset_iter(self): + """Create a mock dataset iterator.""" + iterator = Mock() + iterator.__next__ = Mock(return_value={'input_ids': [1, 2, 3], 'labels': [2, 3, 4]}) + return iterator + + def test_get_batch_naive_mode_returns_dict(self, mock_config, mock_dataset_iter): + """Test get_batch() in naive mode returns dict data.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + # Execute get_batch + batch = trainer.get_batch(mock_dataset_iter) + + # Verify return type and content + assert isinstance(batch, dict) + assert 'input_ids' in batch + assert 
batch['input_ids'] == [1, 2, 3] + + def test_get_batch_distributed_mode(self, mock_config, mock_dataset_iter): + """Test get_batch() in distributed mode.""" + mock_config.use_distribute_dataset = True + + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + batch = trainer.get_batch(mock_dataset_iter) + + assert isinstance(batch, dict) + assert 'input_ids' in batch + + def test_get_batch_remove_redundant_mode(self, mock_config, mock_dataset_iter): + """Test get_batch() in remove redundant mode.""" + mock_config.use_remove_redundant_dataset = True + + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + batch = trainer.get_batch(mock_dataset_iter) + + assert isinstance(batch, dict) + assert 'input_ids' in batch + + def test_get_batch_handles_tuple_data(self, mock_config): + """Test get_batch() converts tuple data to dict.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + # Mock iterator returning tuple + mock_iter = Mock() + mock_iter.__next__ = Mock(return_value=([1, 2, 3], [2, 3, 4])) + + batch = trainer.get_batch(mock_iter) + + # Should convert tuple to dict with 'input_ids' key + assert isinstance(batch, dict) + assert 'input_ids' in batch + + def test_get_batch_handles_list_data(self, mock_config): + """Test get_batch() converts list data to dict.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + # Mock iterator returning list + mock_iter = Mock() + mock_iter.__next__ = Mock(return_value=[[1, 2, 3], [2, 3, 4]]) + + batch = trainer.get_batch(mock_iter) + + # Should convert list to dict with 'input_ids' key + assert isinstance(batch, dict) + assert 'input_ids' in batch + + def test_get_batch_handles_none_data(self, mock_config): + """Test get_batch() handles None data gracefully.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + # Mock iterator returning None + mock_iter = Mock() + mock_iter.__next__ = Mock(return_value=None) + + batch = 
trainer.get_batch(mock_iter) + + # Should return empty dict + assert isinstance(batch, dict) + assert len(batch) == 0 + + def test_get_batch_calls_correct_internal_method(self, mock_config, mock_dataset_iter): + """Test get_batch() calls the correct internal method based on config.""" + trainer = Trainer.__new__(Trainer) + trainer.config = mock_config + + # Mock internal methods + trainer._get_batch_naive = Mock(return_value={'input_ids': [1, 2, 3]}) + trainer._get_batch_distributed = Mock() + trainer._get_batch_remove_redundant = Mock() + + # Test naive mode + trainer.get_batch(mock_dataset_iter) + trainer._get_batch_naive.assert_called_once() + + # Test distributed mode + trainer.config.use_distribute_dataset = True + trainer._get_batch_distributed.reset_mock() + trainer.get_batch(mock_dataset_iter) + trainer._get_batch_distributed.assert_called_once() + + # Test remove redundant mode + trainer.config.use_distribute_dataset = False + trainer.config.use_remove_redundant_dataset = True + trainer._get_batch_remove_redundant.reset_mock() + trainer.get_batch(mock_dataset_iter) + trainer._get_batch_remove_redundant.assert_called_once() + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) diff --git a/tests/st/test_ut/test_trainer/test_trainstate_integration.py b/tests/st/test_ut/test_trainer/test_trainstate_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..88e82ee3e57db3c5d19806d144ddca507c540d2e --- /dev/null +++ b/tests/st/test_ut/test_trainer/test_trainstate_integration.py @@ -0,0 +1,266 @@ +"""Integration test for TrainerState with callbacks.""" +# pylint: disable=wrong-import-position,import-outside-toplevel,protected-access +import importlib.util +import os +import sys +import traceback +from unittest.mock import Mock + +# Add project root to path +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../..')) +sys.path.insert(0, project_root) + +# Mock mindspore modules 
+sys.modules['mindspore'] = Mock() +sys.modules['mindspore.nn'] = Mock() +sys.modules['mindspore.dataset'] = Mock() + +# Mock mindformers modules +mock_logger = Mock() +sys.modules['mindformers'] = Mock() +sys.modules['mindformers.tools'] = Mock() +sys.modules['mindformers.tools.logger'] = Mock(logger=mock_logger) +sys.modules['mindformers.checkpoint'] = Mock() + +# Mock CommonInfo and AsyncSaveManager +class MockCommonInfo: + def __init__(self): + self.epoch_num = None + self.step_num = None + self.global_step = None + self.global_batch_size = None + +class MockAsyncSaveManager: + def __init__(self, async_save): + self.async_save = async_save + def prepare_before_save(self): + pass + +sys.modules['mindformers.checkpoint.checkpoint'] = Mock( + CommonInfo=MockCommonInfo, + AsyncSaveManager=MockAsyncSaveManager +) + +# Direct import to avoid __init__.py issues +train_state_path = os.path.join(project_root, 'mindformers', 'trainer_pynative', 'train_state.py') +spec = importlib.util.spec_from_file_location("train_state", train_state_path) +train_state_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(train_state_module) +TrainerState = train_state_module.TrainerState + +# Now import callbacks +from mindformers.core.callback_pynative.checkpoint_callback import CheckpointCallback +from mindformers.core.callback_pynative.loss_callback import LossCallback + + +def test_trainstate_with_checkpoint_callback(): + """ + Feature: TrainerState integration with CheckpointCallback + Description: Verify TrainerState provides required fields to CheckpointCallback + Expectation: CheckpointCallback successfully creates CommonInfo from TrainerState + """ + print("=" * 70) + print("Test 1: TrainerState with CheckpointCallback") + print("=" * 70) + + # Create TrainerState with all required fields + state = TrainerState( + global_step=100, + epoch=1.0, + epoch_step=100, + global_batch_size=64 + ) + + # Create CheckpointCallback + cb = CheckpointCallback( + 
+        save_dir="./test_ckpts",
+        save_interval=100
+    )
+
+    # Test _create_common_info
+    common_info = cb._create_common_info(state)
+
+    # Verify all fields are set correctly
+    assert common_info.global_step == 100, f"Expected 100, got {common_info.global_step}"
+    assert common_info.epoch_num == 1, f"Expected 1, got {common_info.epoch_num}"
+    assert common_info.step_num == 0, f"Expected 0, got {common_info.step_num}"
+    assert common_info.global_batch_size == 64, f"Expected 64, got {common_info.global_batch_size}"
+
+    print("[OK] CheckpointCallback correctly uses TrainerState.global_batch_size")
+    print(f"    global_batch_size: {common_info.global_batch_size}")
+
+    return True
+
+
+def test_trainstate_with_loss_callback():
+    """
+    Feature: TrainerState integration with LossCallback
+    Description: Verify that TrainerState works correctly with LossCallback during training step callbacks
+    Expectation: LossCallback can access and use TrainerState fields without errors
+    """
+    print("\n" + "=" * 70)
+    print("Test 2: TrainerState with LossCallback")
+    print("=" * 70)
+
+    # Create TrainerState
+    state = TrainerState(
+        global_step=10,
+        epoch=0.5,
+        epoch_step=20,
+        global_batch_size=32
+    )
+
+    # Create LossCallback
+    cb = LossCallback(log_interval=1)
+
+    # Simulate step_begin
+    args = Mock()
+    cb.on_step_begin(args, state)
+
+    # Verify state can be used
+    assert state.global_step == 10
+    assert state.epoch == 0.5
+
+    print("[OK] LossCallback correctly uses TrainerState")
+    print(f"    global_step: {state.global_step}")
+    print(f"    epoch: {state.epoch}")
+
+    return True
+
+
+def test_trainstate_all_required_fields():
+    """
+    Feature: TrainerState required fields validation
+    Description: Verify that TrainerState contains all required fields needed by various callbacks
+    Expectation: All 11 required fields (global_step, epoch, epoch_step, etc.)
are present in TrainerState + """ + print("\n" + "=" * 70) + print("Test 3: TrainerState has all required fields") + print("=" * 70) + + state = TrainerState() + + required_fields = [ + 'global_step', + 'epoch', + 'epoch_step', + 'global_batch_size', + 'max_steps', + 'eval_steps', + 'save_steps', + 'best_metric', + 'best_model_checkpoint', + 'is_train_begin', + 'is_train_end', + ] + + missing_fields = [] + for field in required_fields: + if not hasattr(state, field): + missing_fields.append(field) + + if missing_fields: + print(f"[FAIL] Missing fields: {missing_fields}") + return False + + print("[OK] All required fields present") + for field in required_fields: + value = getattr(state, field) + print(f" {field}: {value}") + + return True + + +def test_trainstate_update_epoch(): + """ + Feature: TrainerState epoch calculation + Description: Verify that update_epoch method correctly calculates epoch based on global_step and epoch_step + Expectation: Epoch is calculated as global_step / epoch_step (e.g., 250 / 100 = 2.5) + """ + print("\n" + "=" * 70) + print("Test 4: TrainerState.update_epoch") + print("=" * 70) + + state = TrainerState( + global_step=250, + epoch_step=100 + ) + + state.update_epoch() + + assert state.epoch == 2.5, f"Expected 2.5, got {state.epoch}" + + print("[OK] update_epoch works correctly") + print(f" global_step: {state.global_step}") + print(f" epoch_step: {state.epoch_step}") + print(f" calculated epoch: {state.epoch}") + + return True + + +def test_trainstate_with_different_batch_sizes(): + """ + Feature: TrainerState batch size handling + Description: Verify that TrainerState correctly stores and retrieves different global_batch_size values + Expectation: TrainerState accurately maintains batch size values from 1 to 512 + """ + print("\n" + "=" * 70) + print("Test 5: TrainerState with different batch sizes") + print("=" * 70) + + batch_sizes = [1, 16, 32, 64, 128, 256, 512] + + for batch_size in batch_sizes: + state = 
TrainerState(global_batch_size=batch_size) + assert state.global_batch_size == batch_size, \ + f"Expected {batch_size}, got {state.global_batch_size}" + + print(f"[OK] Tested {len(batch_sizes)} different batch sizes") + print(f" Batch sizes: {batch_sizes}") + + return True + + +def main(): + """Run all integration tests.""" + print("=" * 70) + print("TrainerState Integration Tests") + print("=" * 70) + + tests = [ + test_trainstate_with_checkpoint_callback, + test_trainstate_with_loss_callback, + test_trainstate_all_required_fields, + test_trainstate_update_epoch, + test_trainstate_with_different_batch_sizes, + ] + + passed = 0 + failed = 0 + + for test_func in tests: + try: + if test_func(): + passed += 1 + except AssertionError as e: + failed += 1 + print(f"\n[FAIL] {test_func.__name__}: {e}") + traceback.print_exc() + except Exception as e: + failed += 1 + print(f"\n[ERROR] {test_func.__name__}: {e}") + traceback.print_exc() + + print("\n" + "=" * 70) + print(f"Integration Test Results: {passed} passed, {failed} failed") + print("=" * 70) + + if failed == 0: + print("[OK] ALL INTEGRATION TESTS PASSED") + return 0 + print("[FAIL] SOME INTEGRATION TESTS FAILED") + return 1 + + +if __name__ == '__main__': + sys.exit(main())