diff --git a/docs/features/reset-bucket-group-order.md b/docs/features/reset-bucket-group-order.md
new file mode 100644
index 0000000000000000000000000000000000000000..9fbb856e40669cf532bc4bf749b9a5fee6a191a5
--- /dev/null
+++ b/docs/features/reset-bucket-group-order.md
@@ -0,0 +1,31 @@
+# Bucket reordering algorithm
+
+## Background and challenges
+
+In large-model training it is common for the order in which modules are defined to differ from the order in which they execute, especially when standard transformer components are redefined or when multimodal models are trained. With `--overlap-param-gather` enabled, this mismatch leads to accuracy problems and to computation and communication running serially.
+The scheme in Megatron 0.12.1 fixes the accuracy problem, but computation and communication still end up serialized.
+
+## Solution
+
+To address this, a parameter-bucket reordering strategy is introduced: the bucket_group order observed in the first iteration is recorded and reused in later iterations, so that communication is pipelined behind computation and resource utilization improves.
+
+### With overlap-param-gather enabled, turn on the `--reset-bucket-group-order` flag
+During the first iteration we record the order in which bucket groups are triggered; once the first forward pass finishes, the full bucket order is known. From the second iteration onward, only the first bucket's all-gather cannot be overlapped with computation; every subsequent prefetch of the next bucket's all-gather overlaps with the current computation. The algorithm is illustrated in the figure below.
+
+
+  
+
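+A minimal sketch of the idea (illustrative only: `BucketGroup`, `recorded_order`, and `on_first_param_use` are stand-ins assumed for this sketch, not Megatron's or MindSpeed's actual classes; the real hook lives in `distributed_data_parallel.py` below):
+
+```python
+class BucketGroup:
+    """Stand-in for a param-and-grad bucket group."""
+
+    def __init__(self, name):
+        self.name = name
+        # Bucket group whose param all-gather is prefetched once this one finishes.
+        self.next_param_gather_bucket_group = None
+
+
+recorded_order = []  # filled during the first forward pass
+
+
+def on_first_param_use(bucket_group, all_bucket_groups):
+    """Record execution order; once complete, chain the groups for prefetching."""
+    if bucket_group not in recorded_order:
+        recorded_order.append(bucket_group)
+        if len(recorded_order) == len(all_bucket_groups):
+            # Link groups in *execution* order (not definition order), so finishing
+            # one group's all-gather immediately prefetches the next group used.
+            for cur, nxt in zip(recorded_order, recorded_order[1:]):
+                cur.next_param_gather_bucket_group = nxt
+```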

+
+## Use cases
+
+This feature targets training with data parallelism and is most useful when the module definition order differs heavily from the execution order. In that case bucket communication is effectively unordered, computation and communication are largely serialized, and `--overlap-param-gather` alone brings little benefit; enabling `--reset-bucket-group-order` restores the overlap.
+
+## Usage
+
+* To enable the bucket reordering algorithm, add the following flag to the training configuration:
+  `--reset-bucket-group-order`
+* Make sure the following flags are also enabled:
+  `--use-distributed-optimizer`
+  `--overlap-grad-reduce`
+  `--overlap-param-gather`
diff --git a/mindspeed/core/distributed/distributed_data_parallel.py b/mindspeed/core/distributed/distributed_data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d0332bed60f2c132d834f61759e9b75b6428150
--- /dev/null
+++ b/mindspeed/core/distributed/distributed_data_parallel.py
@@ -0,0 +1,58 @@
+from functools import wraps
+from megatron.core.transformer.cuda_graphs import is_graph_capturing
+
+
+def DistributedDataParallel_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        res = fn(self, *args, **kwargs)
+        if self.ddp_config.reset_bucket_group_order:
+            # Records the order in which bucket groups are first used in forward.
+            self.bucket_group_index_list = []
+        return res
+    return wrapper
+
+
+def _make_forward_pre_hook(self):
+    """
+    Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
+    when a module uses a parameter in a bucket with a still incomplete all-gather).
+    """
+
+    def hook(module, *unused):
+        assert (
+            self.use_forward_hook
+        ), "Should use pre-hook only when overlap_param_gather is True"
+
+        if is_graph_capturing():
+            return
+
+        # Make sure all parameters in this module have been all-gathered as necessary.
+        for param in module.parameters(recurse=False):
+            # Skip parameters without an associated buffer (such parameters have a
+            # .requires_grad field equal to False).
+            if param not in self.param_to_bucket_group:
+                continue
+            assert param.requires_grad
+
+            # Record the execution order of bucket groups during the first iteration;
+            # once every group has been seen, chain them so each group prefetches the
+            # next one in execution order.
+            if (self.ddp_config.reset_bucket_group_order
+                    and self.param_to_bucket_group[param] not in self.bucket_group_index_list):
+                self.bucket_group_index_list.append(self.param_to_bucket_group[param])
+                if len(self.bucket_group_index_list) == len(self.bucket_groups):
+                    for i in range(len(self.bucket_group_index_list) - 1):
+                        self.bucket_group_index_list[i].next_param_gather_bucket_group = (
+                            self.bucket_group_index_list[i + 1]
+                        )
+
+            # If aligning param all-gather across pipeline stages, all-gather is dispatched
+            # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
+            # If overlapping param all-gather with optimizer step, then all-gather has
+            # already been dispatched in optimizer step.
+            skip_next_bucket_dispatch = (
+                self.ddp_config.align_param_gather
+                or self.overlap_param_gather_with_optimizer_step
+            )
+            self.param_to_bucket_group[param].finish_param_sync(
+                skip_next_bucket_dispatch=skip_next_bucket_dispatch
+            )
+
+    return hook
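The hook above relies on PyTorch forward pre-hooks firing in the order modules actually execute, which is why the recorded list reflects execution order rather than definition order. A small self-contained illustration of that property (plain PyTorch; the module names here are made up and unrelated to Megatron's DDP internals):

```python
import torch

seen = []


class Block(torch.nn.Module):
    def __init__(self, tag):
        super().__init__()
        self.tag = tag
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


blocks = torch.nn.ModuleList([Block("a"), Block("b"), Block("c")])
for block in blocks:
    # Pre-hooks run just before each module's forward call, in execution order.
    block.register_forward_pre_hook(lambda module, inputs: seen.append(module.tag))

x = torch.randn(2, 4)
# Execute in an order that differs from the definition order.
for block in (blocks[2], blocks[0], blocks[1]):
    x = block(x)

print(seen)  # ['c', 'a', 'b'] -- execution order, not definition order
```

With the execution order in hand, `next_param_gather_bucket_group` can be re-chained so each bucket group's all-gather is prefetched during the preceding group's compute.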
diff --git a/mindspeed/core/distributed/distributed_data_parallel_config.py b/mindspeed/core/distributed/distributed_data_parallel_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..86929b8de280e7f41f74e4eb685aaba553e8f0da
--- /dev/null
+++ b/mindspeed/core/distributed/distributed_data_parallel_config.py
@@ -0,0 +1,14 @@
+from functools import wraps
+from megatron.training import get_args
+
+
+def distributed_data_parallel_config_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        fn(*args, **kwargs)
+        self = args[0]
+        global_args = get_args()
+        # Propagate the CLI flag onto the DDP config so DDP code can read it.
+        self.reset_bucket_group_order = getattr(global_args, "reset_bucket_group_order", False)
+    return wrapper
diff --git a/mindspeed/features_manager/__init__.py b/mindspeed/features_manager/__init__.py
index bd602f03818fe739694666e0b15a661dd60be356..aec64190e4ccb94257d946a7b7bc33908daca4a6 100644
--- a/mindspeed/features_manager/__init__.py
+++ b/mindspeed/features_manager/__init__.py
@@ -75,6 +75,8 @@ from mindspeed.features_manager.dist_train.dist_train_feature import DistTrainFe
 from mindspeed.features_manager.tokenizer.build_tokenizer import BuildTokenizerFeature
 from mindspeed.features_manager.distributed.buffer_pad import BufferPadFeature
 from mindspeed.features_manager.distributed.torch_fully_sharded_data_parallel import TorchFullyShardedDataParallelFeature
+from mindspeed.features_manager.distributed.reset_bucket_group_order_feature import ResetBucketGroupOrderFeature
+
 from mindspeed.features_manager.custom_fsdp.custom_fsdp_feature import CustomFSDPFeature
 from mindspeed.features_manager.tensor_parallel.tp_2d import TP2dFeature
 from mindspeed.features_manager.compress_dense.compress_dense import AnsCompressTensorFeature
@@ -246,7 +248,8 @@ def add_distributed_features(features_list: List[MindSpeedFeature]):
         BufferPadFeature(),
         LayerZeroFeature(),
         TorchFullyShardedDataParallelFeature(),
-        CustomFSDPFeature()
+        CustomFSDPFeature(),
+        ResetBucketGroupOrderFeature()
     ])
diff --git a/mindspeed/features_manager/distributed/reset_bucket_group_order_feature.py b/mindspeed/features_manager/distributed/reset_bucket_group_order_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..05f940e336e5065e07752ec21bd5f7ac6ba41373
--- /dev/null
+++ b/mindspeed/features_manager/distributed/reset_bucket_group_order_feature.py
@@ -0,0 +1,42 @@
+from argparse import ArgumentParser
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class ResetBucketGroupOrderFeature(MindSpeedFeature):
+
+    def __init__(self):
+        super().__init__('reset-bucket-group-order', 2)
+
+    def register_args(self, parser: ArgumentParser):
+        group = parser.add_argument_group(title=self.feature_name)
+        group.add_argument('--reset-bucket-group-order',
+                           action='store_true', default=False,
+                           help='If set, record the bucket group trigger order during the first '
+                                'iteration so that later iterations overlap param all-gather '
+                                'with forward compute in the correct order.')
+
+    def validate_args(self, args):
+        reset_bucket_group_order = getattr(args, "reset_bucket_group_order", False)
+        overlap_param_gather = getattr(args, "overlap_param_gather", False)
+        if reset_bucket_group_order and not overlap_param_gather:
+            raise AssertionError('--reset-bucket-group-order requires --overlap-param-gather to be enabled')
+
+    def register_patches(self, patch_manager, args):
+        if getattr(args, "reset_bucket_group_order", False):
+            from mindspeed.core.distributed.distributed_data_parallel_config import \
+                distributed_data_parallel_config_init_wrapper
+            patch_manager.register_patch(
+                'megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig.__init__',
+                distributed_data_parallel_config_init_wrapper)
+
+            from mindspeed.core.distributed.distributed_data_parallel import DistributedDataParallel_init_wrapper
+            patch_manager.register_patch(
+                'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.__init__',
+                DistributedDataParallel_init_wrapper)
+
+            from mindspeed.core.distributed.distributed_data_parallel import _make_forward_pre_hook
+            patch_manager.register_patch(
+                'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel._make_forward_pre_hook',
+                _make_forward_pre_hook)
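A hypothetical sanity check one could run after the first training iteration to confirm that the recorded order is complete and the prefetch chain is wired as intended (attribute names follow the patch above; `check_bucket_group_order` itself is not part of the patch):

```python
def check_bucket_group_order(ddp_model):
    """Hypothetical helper: verify the reorder state after the first iteration."""
    recorded = ddp_model.bucket_group_index_list
    # Every bucket group should have been recorded exactly once.
    assert len(recorded) == len(ddp_model.bucket_groups)
    # Each recorded group should prefetch the group that executes right after it.
    for current, nxt in zip(recorded, recorded[1:]):
        assert current.next_param_gather_bucket_group is nxt
```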