diff --git a/docs/features/reset-bucket-group-order.md b/docs/features/reset-bucket-group-order.md
new file mode 100644
index 0000000000000000000000000000000000000000..9fbb856e40669cf532bc4bf749b9a5fee6a191a5
--- /dev/null
+++ b/docs/features/reset-bucket-group-order.md
@@ -0,0 +1,31 @@
+# Bucket group reordering algorithm
+
+## Background and challenges
+
+During large-model training it is very common for the order in which a model is defined to differ from the order in which it is executed, especially when common transformer components are redefined or when multimodal large models are used. With `--overlap-param-gather` enabled, this mismatch leads to accuracy problems and to computation and communication running serially.
+The current Megatron 0.12.1 solution fixes the accuracy problem, but computation and communication still inevitably run serially.
+
+## Solution
+
+To address this problem, a parameter-bucket reordering strategy is introduced: the bucket_group order of the first iteration is recorded so that, in subsequent iterations, communication can be pipelined behind computation, effectively improving resource utilization.
+
+### With overlap-param-gather enabled, turn on `--reset-bucket-group-order`
+During the first iteration the order in which bucket groups are triggered is recorded; by the end of the first forward pass the complete bucket order is known. From the second iteration onward, only the first bucket's all-gather cannot be overlapped with computation; every subsequent prefetch of the next bucket's communication overlaps with the current computation (see the sketch after the figure). The algorithm logic is illustrated in the figure below.
+
+*(Figure: bucket group reordering; from the second iteration on, the prefetch all-gather of the next bucket group overlaps with the current computation.)*
+
+
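+The core idea can be sketched in plain Python. The stub class and helper below are purely illustrative (they are not part of Megatron or MindSpeed); the real logic lives in the patched forward pre-hook and operates on Megatron's bucket groups through their `next_param_gather_bucket_group` field.
+
+```python
+# Minimal sketch of the reordering step, assuming bucket-group objects that expose
+# a next_param_gather_bucket_group attribute (as Megatron's bucket groups do).
+
+class BucketGroupStub:
+    """Stand-in for a param-gather bucket group (illustration only)."""
+
+    def __init__(self, name):
+        self.name = name
+        self.next_param_gather_bucket_group = None  # which group to prefetch next
+
+
+def record_trigger_order(trigger_order, bucket_group, num_bucket_groups):
+    """Record a bucket group the first time the forward pass touches it; once all
+    groups have been seen, chain them so each group prefetches its successor."""
+    if bucket_group not in trigger_order:
+        trigger_order.append(bucket_group)
+        if len(trigger_order) == num_bucket_groups:
+            for current, successor in zip(trigger_order, trigger_order[1:]):
+                current.next_param_gather_bucket_group = successor
+
+
+# First iteration: the execution order is b2, b0, b1 even though the buckets were
+# built from the definition order b0, b1, b2.
+groups = [BucketGroupStub(f"b{i}") for i in range(3)]
+trigger_order = []
+for group in (groups[2], groups[0], groups[1]):
+    record_trigger_order(trigger_order, group, len(groups))
+
+# From the second iteration on, finishing b2's all-gather prefetches b0, then b1.
+print([g.next_param_gather_bucket_group.name if g.next_param_gather_bucket_group else None
+       for g in trigger_order])  # ['b0', 'b1', None]
+```
+
+In the patched hook this bookkeeping only runs during the first iteration; afterwards every bucket group is already in the recorded list, and the hook simply falls through to the normal `finish_param_sync` call.
+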
+## Usage scenarios
+
+This feature applies to training with data parallelism. It is especially useful when the model definition order is heavily shuffled relative to the execution order: bucket communication then fires out of order, computation and communication are largely serialized, and `--overlap-param-gather` alone brings limited benefit. Enabling `--reset-bucket-group-order` restores the overlap; a toy example of such a model is sketched below.
+
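+The following toy module is a hypothetical illustration (it is not taken from any real model) of a definition order that differs from the execution order, which is exactly the situation this feature targets:
+
+```python
+import torch
+from torch import nn
+
+
+class OutOfOrderBlock(nn.Module):
+    """Toy module whose submodules are defined in one order but executed in another.
+    The DDP bucket layout is derived from the definition order, while the forward
+    pass touches parameters in the execution order."""
+
+    def __init__(self, hidden=16):
+        super().__init__()
+        self.proj = nn.Linear(hidden, hidden)  # defined first, executed second
+        self.attn = nn.Linear(hidden, hidden)  # defined second, executed first
+
+    def forward(self, x):
+        return self.proj(self.attn(x))
+
+
+print(OutOfOrderBlock()(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
+```
+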
+## Usage
+
+* To enable the bucket reordering algorithm, add the following flag to the training configuration:
+  `--reset-bucket-group-order`
+* Make sure the following flags are enabled as well (a sketch of this check follows the list):
+  `--use-distributed-optimizer`
+  `--overlap-grad-reduce`
+  `--overlap-param-gather`
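+
+The snippet below is an illustrative check only; the helper name is made up and this is not the feature's own validation code. It simply mirrors the flag requirements listed above (argparse stores `--foo-bar` as the attribute `foo_bar`):
+
+```python
+# Hypothetical helper: verify the flag combination after argument parsing.
+def check_reset_bucket_group_order_flags(args):
+    if not getattr(args, "reset_bucket_group_order", False):
+        return
+    required = ("use_distributed_optimizer", "overlap_grad_reduce", "overlap_param_gather")
+    missing = [name for name in required if not getattr(args, name, False)]
+    if missing:
+        raise AssertionError(
+            "--reset-bucket-group-order also requires: "
+            + ", ".join("--" + name.replace("_", "-") for name in missing))
+```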
+
+
+
+
diff --git a/mindspeed/core/distributed/distributed_data_parallel.py b/mindspeed/core/distributed/distributed_data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d0332bed60f2c132d834f61759e9b75b6428150
--- /dev/null
+++ b/mindspeed/core/distributed/distributed_data_parallel.py
@@ -0,0 +1,58 @@
+from functools import wraps
+from megatron.core.transformer.cuda_graphs import is_graph_capturing
+
+
+def DistributedDataParallel_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        res = fn(self, *args, **kwargs)
+        if self.ddp_config.reset_bucket_group_order:
+            # Records the order in which bucket groups are first triggered during
+            # the first forward pass; filled in by the forward pre-hook below.
+            self.bucket_group_index_list = []
+        return res
+    return wrapper
+
+
+def _make_forward_pre_hook(self):
+    """
+    Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
+    when a module uses a parameter in a bucket with a still incomplete all-gather).
+    """
+
+    def hook(module, *unused):
+        assert (
+            self.use_forward_hook
+        ), "Should use pre-hook only when overlap_param_gather is True"
+
+        if is_graph_capturing():
+            return
+
+        # Make sure all parameters in this module have been all-gathered as necessary.
+        for param in module.parameters(recurse=False):
+            # Skip parameters without an associated buffer (such parameters have a
+            # .requires_grad field equal to False).
+            if param not in self.param_to_bucket_group:
+                continue
+            assert param.requires_grad
+
+            if self.ddp_config.reset_bucket_group_order:
+                # Record the order in which bucket groups are first triggered
+                # during the first forward pass.
+                bucket_group = self.param_to_bucket_group[param]
+                if bucket_group not in self.bucket_group_index_list:
+                    self.bucket_group_index_list.append(bucket_group)
+                    if len(self.bucket_group_index_list) == len(self.bucket_groups):
+                        # All bucket groups have been seen: chain them in trigger order
+                        # so finishing one group's all-gather prefetches the next one.
+                        for i in range(len(self.bucket_group_index_list) - 1):
+                            self.bucket_group_index_list[i].next_param_gather_bucket_group = (
+                                self.bucket_group_index_list[i + 1]
+                            )
+
+            # If aligning param all-gather across pipeline stages, all-gather is dispatched
+            # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
+            # If overlapping param all-gather with optimizer step, then all-gather has
+            # already been dispatched in optimizer step.
+            skip_next_bucket_dispatch = (
+                self.ddp_config.align_param_gather
+                or self.overlap_param_gather_with_optimizer_step
+            )
+            self.param_to_bucket_group[param].finish_param_sync(
+                skip_next_bucket_dispatch=skip_next_bucket_dispatch
+            )
+
+    return hook
\ No newline at end of file
diff --git a/mindspeed/core/distributed/distributed_data_parallel_config.py b/mindspeed/core/distributed/distributed_data_parallel_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..86929b8de280e7f41f74e4eb685aaba553e8f0da
--- /dev/null
+++ b/mindspeed/core/distributed/distributed_data_parallel_config.py
@@ -0,0 +1,14 @@
+from functools import wraps
+from megatron.training import get_args
+
+
+def distributed_data_parallel_config_init_wrapper(fn):
+    @wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        fn(self, *args, **kwargs)
+        # Mirror the command-line flag onto the DDP config so that
+        # DistributedDataParallel can read it from self.ddp_config.
+        global_args = get_args()
+        self.reset_bucket_group_order = getattr(global_args, "reset_bucket_group_order", False)
+    return wrapper
\ No newline at end of file
diff --git a/mindspeed/features_manager/__init__.py b/mindspeed/features_manager/__init__.py
index bd602f03818fe739694666e0b15a661dd60be356..aec64190e4ccb94257d946a7b7bc33908daca4a6 100644
--- a/mindspeed/features_manager/__init__.py
+++ b/mindspeed/features_manager/__init__.py
@@ -75,6 +75,8 @@ from mindspeed.features_manager.dist_train.dist_train_feature import DistTrainFe
from mindspeed.features_manager.tokenizer.build_tokenizer import BuildTokenizerFeature
from mindspeed.features_manager.distributed.buffer_pad import BufferPadFeature
from mindspeed.features_manager.distributed.torch_fully_sharded_data_parallel import TorchFullyShardedDataParallelFeature
+from mindspeed.features_manager.distributed.reset_bucket_group_order_feature import ResetBucketGroupOrderFeature
+
from mindspeed.features_manager.custom_fsdp.custom_fsdp_feature import CustomFSDPFeature
from mindspeed.features_manager.tensor_parallel.tp_2d import TP2dFeature
from mindspeed.features_manager.compress_dense.compress_dense import AnsCompressTensorFeature
@@ -246,7 +248,8 @@ def add_distributed_features(features_list: List[MindSpeedFeature]):
BufferPadFeature(),
LayerZeroFeature(),
TorchFullyShardedDataParallelFeature(),
- CustomFSDPFeature()
+ CustomFSDPFeature(),
+ ResetBucketGroupOrderFeature()
])
diff --git a/mindspeed/features_manager/distributed/reset_bucket_group_order_feature.py b/mindspeed/features_manager/distributed/reset_bucket_group_order_feature.py
new file mode 100644
index 0000000000000000000000000000000000000000..05f940e336e5065e07752ec21bd5f7ac6ba41373
--- /dev/null
+++ b/mindspeed/features_manager/distributed/reset_bucket_group_order_feature.py
@@ -0,0 +1,42 @@
+from argparse import ArgumentParser
+
+from mindspeed.features_manager.feature import MindSpeedFeature
+
+
+class ResetBucketGroupOrderFeature(MindSpeedFeature):
+
+    def __init__(self):
+        super().__init__('reset-bucket-group-order', 2)
+
+    def register_args(self, parser: ArgumentParser):
+        group = parser.add_argument_group(title=self.feature_name)
+        group.add_argument('--reset-bucket-group-order',
+                           action='store_true', default=False,
+                           help='If set, record the bucket-group trigger order of the first '
+                                'iteration so that param all-gather prefetch overlaps with '
+                                'forward compute in later iterations.')
+
+    def validate_args(self, args):
+        reset_bucket_group_order = getattr(args, "reset_bucket_group_order", False)
+        overlap_param_gather = getattr(args, "overlap_param_gather", False)
+        if reset_bucket_group_order and not overlap_param_gather:
+            raise AssertionError('--reset-bucket-group-order requires --overlap-param-gather to be enabled')
+
+    def register_patches(self, patch_manager, args):
+        if getattr(args, "reset_bucket_group_order", False):
+            from mindspeed.core.distributed.distributed_data_parallel_config import \
+                distributed_data_parallel_config_init_wrapper
+            patch_manager.register_patch(
+                'megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig.__init__',
+                distributed_data_parallel_config_init_wrapper)
+
+            from mindspeed.core.distributed.distributed_data_parallel import DistributedDataParallel_init_wrapper
+            patch_manager.register_patch(
+                'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.__init__',
+                DistributedDataParallel_init_wrapper)
+
+            from mindspeed.core.distributed.distributed_data_parallel import _make_forward_pre_hook
+            patch_manager.register_patch(
+                'megatron.core.distributed.distributed_data_parallel.DistributedDataParallel._make_forward_pre_hook',
+                _make_forward_pre_hook)
\ No newline at end of file