diff --git a/mindformers/core/callback/callback.py b/mindformers/core/callback/callback.py
index fb39707b2c0da40167a9308b1bbefb8d3fa4e178..56c3a575415b079a053f0e6cdd7c7ed8da824066 100644
--- a/mindformers/core/callback/callback.py
+++ b/mindformers/core/callback/callback.py
@@ -1583,13 +1583,9 @@ class CheckpointMonitor(ModelCheckpoint):
 
     def remove_redundancy(self, network, cur_file, append_dict, train_network):
         """remove redundancy when saving checkpoint files."""
-        if self._config.remove_redundancy:
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self._config.remove_redundancy and parallel_mode != "stand_alone":
             logger.info('......Removing redundancy......')
-            parallel_mode = context.get_auto_parallel_context("parallel_mode")
-            if parallel_mode == "stand_alone":
-                raise TypeError(f"The deduplication feature for saving checkpoint can only be used "
-                                f"in parallel scenarios, but got {parallel_mode}.")
-
             if train_network:
                 param_layout = train_network.parameter_layout_dict
             else:
diff --git a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
index 3dd3357ed93848d150ea2b2190de6cb7e8d78f8a..1a694dfebfe347848b8a870be083a3baaffd4690 100644
--- a/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
+++ b/mindformers/parallel_core/training_graph/base_models/gpt/gpt_model.py
@@ -266,8 +266,9 @@ class GPTModel(nn.Cell):
         self.pp = config.pipeline_model_parallel_size if config.pipeline_model_parallel_size is not None else 1
         self.cp = config.context_parallel_size if config.context_parallel_size is not None else 1
 
-        initialize_model_parallel(tensor_model_parallel_size=self.tp, data_parallel_size=self.dp,
-                                  pipeline_model_parallel_size=self.pp, context_parallel_size=self.cp)
+        if _get_parallel_mode() != ParallelMode.STAND_ALONE:
+            initialize_model_parallel(tensor_model_parallel_size=self.tp, data_parallel_size=self.dp,
+                                      pipeline_model_parallel_size=self.pp, context_parallel_size=self.cp)
 
         self.preprocess_labels_and_masks = PreprocessLabelsAndMasks(config)
 
diff --git a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py
index af89ca425805ac825904bff1ecec44a55e7d8b05..616516ced3c02777be0fbe6a5a34e902df89e32d 100644
--- a/mindformers/parallel_core/training_graph/transformer/moe/ffn.py
+++ b/mindformers/parallel_core/training_graph/transformer/moe/ffn.py
@@ -83,6 +83,12 @@ class FFNGroupedGEMM(nn.Cell):
                                  name='w2')
 
         # init token dispatcher
+        if _get_parallel_mode() == ParallelMode.STAND_ALONE:
+            if self.moe_token_dispatcher_type != "alltoall":
+                raise ValueError(
+                    f"In STAND_ALONE mode, only 'alltoall' is supported for "
+                    f"moe_token_dispatcher_type, but got {self.moe_token_dispatcher_type!r}.")
+
         if self.moe_token_dispatcher_type == "alltoall":
             self.token_dispatcher = MoEAlltoAllTokenDispatcher(config)
         elif self.moe_token_dispatcher_type == "alltoall_deredundency":
diff --git a/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py b/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py
index b5161dcf0f974cc4b180c4dde91c945753fde19a..16939c523db2343976dc9fcaefb8c9ab1d91a2ba 100644
--- a/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py
+++ b/mindformers/parallel_core/training_graph/transformer/moe/token_dispatcher.py
@@ -21,8 +21,10 @@ import numpy as np
 import mindspore as ms
 import mindspore.ops as ops
 import mindspore.mint as mint
+from mindspore import nn, ParallelMode
 from mindspore.common.tensor import Tensor
 from mindspore.communication import create_group, get_rank
+from mindspore.parallel._utils import _get_parallel_mode
 from mindspore.ops.auto_generate import CumsumExt, FmodScalar, SortExt, IndexSelect, OneHotExt, Cast, Reshape, Zeros, Transpose, ReduceSum, MaskedSelect
 from mindformers.parallel_core.transformer_config import TransformerConfig
 from mindformers.version_control import get_all2allvc
@@ -31,6 +33,36 @@ _OEP_GROUP_NAME = {}
 _IEP_GROUP_NAME = {}
 
 
+class AlltoAll(nn.Cell):
+    """AlltoAll operation wrapper."""
+    def __init__(self, split_count, split_dim, concat_dim, group=None):
+        super().__init__()
+        self.group_is_none = group is None
+        if not self.group_is_none:
+            self.ops = ops.AlltoAll(split_count, split_dim, concat_dim, group)
+
+    def construct(self, input_x):
+        if self.group_is_none:
+            return input_x
+        input_x = self.ops(input_x)
+        return input_x
+
+
+class AlltoAllV(nn.Cell):
+    """AlltoAllV operation wrapper."""
+    def __init__(self, group=None, block_size=1):
+        super().__init__()
+        self.group_is_none = group is None
+        if not self.group_is_none:
+            self.ops = ops.AlltoAllV(group=group, block_size=block_size)
+
+    def construct(self, input_x, send_numel_list, recv_numel_list):
+        if self.group_is_none:
+            return input_x
+        tensor = self.ops(input_x, send_numel_list, recv_numel_list)
+        return tensor
+
+
 class MoETokenDispatcher:
     """
     MoE Token Dispatcher
@@ -52,6 +84,8 @@ class MoETokenDispatcher:
 
     def _ep_group(self):
         """Get expert model parallel group."""
+        if _get_parallel_mode() == ParallelMode.STAND_ALONE:
+            return None
         rank_id = get_rank()
         ep = self.config.expert_model_parallel_size
 
@@ -491,7 +525,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         It also initializes the necessary data structures for AlltoAll communication,
         such as input and output splits, and the mapping between global tokens and local experts.
         """
-        num_global_tokens_per_expert = ops.AlltoAll(
+        num_global_tokens_per_expert = AlltoAll(
             split_count=self.ep,
             split_dim=-1,
             concat_dim=-2,
@@ -594,10 +628,10 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
         # Perform expert parallel AlltoAll communication
         # The shape change is: global_input_tokens <- [B, S, h]
         original_shape = permuted_input.shape
-        global_input_tokens = ops.AlltoAllV(group=self.ep_group, block_size=self.hidden_size)(
+        global_input_tokens = AlltoAllV(group=self.ep_group, block_size=self.hidden_size)(
             permuted_input.reshape(-1), input_splits, output_splits).reshape(1, -1, self.hidden_size)
         # The shape change is: routing_map <- [B, S]
-        routing_map = ops.AlltoAllV(group=self.ep_group, block_size=1)(
+        routing_map = AlltoAllV(group=self.ep_group, block_size=1)(
             routing_map.astype(ms.float32).reshape(-1), input_splits, output_splits).reshape(1, -1)
 
         # Permutation 2: Sort tokens by local expert.
@@ -662,7 +696,7 @@ class MoEAlltoAllTokenDispatcher(MoETokenDispatcher):
 
 
         # Perform expert parallel AlltoAll communication
-        permutated_local_input_tokens = ops.AlltoAllV(group=self.ep_group, block_size=self.hidden_size)(
+        permutated_local_input_tokens = AlltoAllV(group=self.ep_group, block_size=self.hidden_size)(
             tokens.reshape(-1), output_splits, input_splits).reshape(1, -1, self.hidden_size)
         permutated_local_input_tokens = permutated_local_input_tokens.reshape(original_shape)
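For reference, a minimal sketch of the behaviour the new wrappers are meant to provide in stand-alone runs: when no communication group is supplied (as `_ep_group` now returns None under STAND_ALONE mode), both `AlltoAll` and `AlltoAllV` skip the collective and pass their input through unchanged. The snippet below assumes this patch is applied and MindSpore is importable; the tensor shape and split lists are illustrative only, not taken from the patch.

# Sketch: stand-alone (group=None) path of the patched wrappers.
# Assumes the patch above is applied; shapes and split lists are made up.
import numpy as np
from mindspore import Tensor
from mindformers.parallel_core.training_graph.transformer.moe.token_dispatcher import AlltoAll, AlltoAllV

tokens = Tensor(np.ones((1, 8, 4), np.float32))

# group defaults to None, so ops.AlltoAll is never constructed and the
# input tensor is returned unchanged.
out = AlltoAll(split_count=1, split_dim=-1, concat_dim=-2)(tokens)
assert (out == tokens).all()

# Same for AlltoAllV: the send/recv numel lists are ignored when there is no group.
flat = tokens.reshape(-1)
out_v = AlltoAllV(group=None, block_size=4)(flat, [flat.size], [flat.size])
assert (out_v == flat).all()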