diff --git a/docs/features/dynamic_dpcp.md b/docs/features/dynamic_dpcp.md
new file mode 100644
index 0000000000000000000000000000000000000000..1fb543a2e6b1f9b7f8ba58d6b5d83035fd21068f
--- /dev/null
+++ b/docs/features/dynamic_dpcp.md
@@ -0,0 +1,35 @@
+## Dynamic DP/CP switch
+
+## Problem analysis
+When training on datasets with dynamic shapes, the input size changes from iteration to iteration, and long sequences easily run into OOM. The current MM framework uses a static parallel strategy: the DP/CP degrees are fixed at initialization according to the longest sample in the dataset so that long sequences remain trainable. This static approach is often sub-optimal: when long and short sequences are mixed, CP has to be set large enough for the long sequences, so short sequences are also split across CP ranks and their compute efficiency drops. The degradation is largest when long sequences are rare, short sequences dominate, and the gap between their lengths is large.
+
+## Solution
+At every iteration, the parallel strategy for that step, i.e. DPxCPy, is chosen from the data each DP rank has just received, and data redistribution and parallel-group switching are performed together. Long sequences stay trainable while short sequences keep their compute efficiency.
+
+- Parallel-group initialization: after Megatron initialization, the list of candidate parallel groups is generated automatically from the user-configured maximum CP degree, e.g. CP=4 gives the groups {dp4cp1, dp2cp2, dp1cp4}.
+
+- New parallel-strategy selection: with dynamic DP/CP enabled, the dataloader fetches data as DPnCP1 by default. Every iteration gathers the data sizes of all DP ranks, walks through the parallel-group list and picks the best strategy that satisfies the sequence-length constraint (DP first).
+
+- Parallel-group switching and data redistribution: once the new strategy is obtained, the global parallel-group variables are redirected to the new groups. When switching from DP to CP the data must be broadcast; ranks that receive the broadcast cache the samples they had already fetched and read from this cache first in the following iterations.
+
+## Supported version
+opensoraplan1.3
+
+## Usage
+Parameter location: pretrain_t2v_model.json
+
+| Parameter | Description |
+| --------------------------- | -------------------------------------------------- |
+| use_dynamic_dpcp | switch for the dynamic DP/CP feature |
+| max_cp_size | maximum CP degree supported by the cluster, e.g. CP=4 gives the groups {dp4cp1, dp2cp2, dp1cp4} |
+| max_seq_size | maximum sequence size a single device can compute, e.g. the longest video is (23x1080x720) (fps, h, w) |
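+
+The new options sit at the top level of the model config. Below is a minimal, abridged sketch of the relevant part of examples/opensoraplan1.3/t2v/pretrain_t2v_model.json (the values follow the defaults added in this change, except that use_dynamic_dpcp is set to true to enable the feature; the rest of the model configuration is omitted):
+
+```json
+{
+    "use_dynamic_dpcp": true,
+    "max_cp_size": 4,
+    "max_seq_size": 54853632
+}
+```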
diff --git a/examples/opensoraplan1.3/t2v/pretrain_t2v_model.json b/examples/opensoraplan1.3/t2v/pretrain_t2v_model.json
index 943a887a3eb15b1e7a519da00caac6c4cf6b0e77..6cd4ef00f4da748921c083c8063b3354dbb3a8c3 100644
--- a/examples/opensoraplan1.3/t2v/pretrain_t2v_model.json
+++ b/examples/opensoraplan1.3/t2v/pretrain_t2v_model.json
@@ -67,6 +67,9 @@
     },
     "patch": {
         "ae_float32": true
-    }
+    },
+    "use_dynamic_dpcp": false,
+    "max_cp_size": 4,
+    "max_seq_size": 54853632
 }
diff --git a/mindspeed_mm/models/common/attention.py b/mindspeed_mm/models/common/attention.py
index e57a3b977739eb1500bf67c8dc764ccdc87b0c96..686bf94fe8ba0ad47a9cf5324621f48370df9209 100644
--- a/mindspeed_mm/models/common/attention.py
+++ b/mindspeed_mm/models/common/attention.py
@@ -375,7 +375,7 @@ class ParallelAttention(nn.Module):
                 fa_layout=self.fa_layout,
                 softmax_scale=1 / math.sqrt(self.head_dim)
             )
-        if self.cp_size > 1 and args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"]:
+        if mpu.get_context_parallel_world_size() > 1 and args.context_parallel_algo in ["ulysses_cp_algo", "hybrid_cp_algo"]:
             ulysses_group = mpu.get_context_parallel_group()
             if args.context_parallel_algo == "hybrid_cp_algo":
                 ulysses_group = get_context_parallel_group_for_hybrid_ulysses()
@@ -552,7 +552,7 @@ class MultiHeadSparseAttentionSBH(ParallelAttention):
         self.sparse_group = sparse_group
 
         if args.context_parallel_algo == 'ulysses_cp_algo':
-            self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, self.cp_size)
+            self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, mpu.get_context_parallel_world_size())
         elif args.context_parallel_algo == 'hybrid_cp_algo':
             self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, args.ulysses_degree_in_cp)
         else:
@@ -643,11 +643,11 @@ class MultiHeadSparseAttentionSBH(ParallelAttention):
         q, k, v = super().function_before_core_attention(query, key, input_layout, rotary_pos_emb=rotary_pos_emb)
 
         total_frames = frames
 
-        if self.cp_size > 1:
+        if mpu.get_context_parallel_world_size() > 1:
             if args.context_parallel_algo == 'ulysses_cp_algo':
                 cp_group = mpu.get_context_parallel_group()
-                total_frames = frames * self.cp_size
+                total_frames = frames * mpu.get_context_parallel_world_size()
                 # apply all_to_all to gather sequence and split attention heads [s // sp, b, h, d] -> [s, b, h // sp, d]
                 q = mapping.all_to_all(q, cp_group, scatter_dim=2, gather_dim=0)
                 k = mapping.all_to_all(k, cp_group, scatter_dim=2, gather_dim=0)
@@ -661,6 +661,7 @@ class MultiHeadSparseAttentionSBH(ParallelAttention):
                 v = mapping.all_to_all(v, cp_group, scatter_dim=2, gather_dim=0)
 
         batch_size = q.shape[1]
+        self.num_attention_heads_per_partition_per_cp = core.utils.divide(self.num_attention_heads_per_partition, mpu.get_context_parallel_world_size())
         q = q.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
         k = k.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
         v = v.view(-1, batch_size, self.num_attention_heads_per_partition_per_cp * self.head_dim)
@@ -694,11 +695,11 @@ class MultiHeadSparseAttentionSBH(ParallelAttention):
     ):
         args = get_args()
         total_frames = frames
-        if self.cp_size > 1:
+        if mpu.get_context_parallel_world_size() > 1:
             if args.context_parallel_algo == 'ulysses_cp_algo':
                 cp_group = mpu.get_context_parallel_group()
-                total_frames = frames * self.cp_size
+                total_frames = frames * mpu.get_context_parallel_world_size()
 
             if args.context_parallel_algo == 'hybrid_cp_algo':
                 cp_group = get_context_parallel_group_for_hybrid_ulysses()
@@ -707,7 +708,7 @@ class MultiHeadSparseAttentionSBH(ParallelAttention):
         if self.sparse1d:
             out = self._reverse_sparse_1d(out, total_frames, height, width)
 
-        if self.cp_size > 1:
+        if mpu.get_context_parallel_world_size() > 1:
             if args.context_parallel_algo == 'ulysses_cp_algo':
                 cp_group = mpu.get_context_parallel_group()
                 out = mapping.all_to_all(out, cp_group, scatter_dim=0, gather_dim=2)
@@ -732,7 +733,7 @@ class MultiHeadSparseAttentionSBH(ParallelAttention):
         if mask is not None and args.context_parallel_algo not in ['megatron_cp_algo', 'hybrid_cp_algo']:
             mask = mask.view(batch_size, 1, -1, mask.shape[-1])
 
-        if self.cp_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
+        if mpu.get_context_parallel_world_size() > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo']:
             scale = 1.0 / math.sqrt(self.head_dim)
             head_num = self.num_attention_heads_per_partition_per_cp
             cp_para = self.cp_para
diff --git a/mindspeed_mm/training.py b/mindspeed_mm/training.py
index 927aa003fcfdda127f81a129c3de60202a8e83d7..957c11c2ffcdb3983ac6522f2413c85645d9c999 100644
--- a/mindspeed_mm/training.py
+++ b/mindspeed_mm/training.py
@@ -119,6 +119,12 @@ def pretrain(
     )
     args = get_args()
+    from mindspeed_mm.utils.dpcp_utils import initialize_parall_switch_list
+    from datetime import timedelta
+    timeout = timedelta(minutes=10)
+    if hasattr(args.mm.model, "use_dynamic_dpcp") and args.mm.model.use_dynamic_dpcp:
+        initialize_parall_switch_list(timeout)
+        print_rank_0("dynamic dpcp is enabled")
     merge_mm_args(args)
 
     if args.log_throughput:
diff --git a/mindspeed_mm/utils/dpcp_utils.py b/mindspeed_mm/utils/dpcp_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5222f95692ea7013bc3b036290fea60dfd6b708
--- /dev/null
+++ b/mindspeed_mm/utils/dpcp_utils.py
@@ -0,0 +1,204 @@
+
+import os
+import importlib
+import copy
+from einops import rearrange
+import torch
+import torch.distributed
+import numpy as np
+import megatron.core.parallel_state as global_mpu
+from megatron.training import get_args, print_rank_0
+from mindspeed_mm.data.data_utils.constants import (
+    VIDEO,
+    PROMPT_IDS,
+    PROMPT_MASK,
+    VIDEO_MASK
+)
+
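+# Module-level state for the dynamic DP/CP switch:
+#   _PARALLEL_STRATEGY_LIST  -- candidate [dp_size, cp_size] options, ordered from largest DP (cp=1) to largest CP
+#   _PARALLEL_STRATEGY_GROUP -- pre-created process groups for every candidate option
+#   CACHED_BATCH             -- samples set aside by ranks whose own batch is replaced by a broadcast one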
+_PARALLEL_STRATEGY_LIST = []
+_PARALLEL_STRATEGY_GROUP = {}
+CACHED_BATCH = []
+
+
+def deep_copy_batch(batch_data):
+    if isinstance(batch_data, torch.Tensor):
+        return batch_data.clone().detach()
+    elif isinstance(batch_data, dict):
+        return {key: deep_copy_batch(value) for key, value in batch_data.items()}
+    elif isinstance(batch_data, list):
+        return [deep_copy_batch(item) for item in batch_data]
+    elif isinstance(batch_data, tuple):
+        return tuple(deep_copy_batch(item) for item in batch_data)
+    elif isinstance(batch_data, (str, int, float, bool, type(None))):
+        return batch_data
+    else:
+        return copy.deepcopy(batch_data)
+
+
+def generate_parallel_strategy_options(max_cp_size):
+    ori_dp_size = global_mpu.get_data_parallel_world_size()
+    ori_cp_size = global_mpu.get_context_parallel_world_size()
+    ori_dpcp_world_size = ori_dp_size * ori_cp_size
+
+    max_cp_size = min(ori_dpcp_world_size, max_cp_size)
+    max_cp_power = int(max_cp_size).bit_length() - 1
+
+    global _PARALLEL_STRATEGY_LIST
+    for cp_power in range(max_cp_power + 1):
+        cp_size = 2**cp_power
+        dp_size = int(ori_dpcp_world_size // cp_size)
+        _PARALLEL_STRATEGY_LIST.append([dp_size, cp_size])
+
+
+def initialize_parall_switch_list(timeout):
+    args = get_args()
+    generate_parallel_strategy_options(args.mm.model.max_cp_size)
+    for index, option in enumerate(_PARALLEL_STRATEGY_LIST):
+        rank_generator = global_mpu.RankGenerator(
+            tp=global_mpu.get_tensor_model_parallel_world_size(),
+            ep=global_mpu.get_expert_model_parallel_world_size(),
+            dp=option[0],
+            pp=global_mpu.get_pipeline_model_parallel_world_size(),
+            cp=option[1],
+            order="tp-cp-ep-dp-pp",
+        )
+        _PARALLEL_STRATEGY_GROUP[index] = {}
+        _PARALLEL_STRATEGY_GROUP[index]['dp_group'] = []
+        _PARALLEL_STRATEGY_GROUP[index]['dp_group_gloo'] = []
+        _PARALLEL_STRATEGY_GROUP[index]['dp'] = rank_generator.get_ranks('dp')
+
+        for ranks in _PARALLEL_STRATEGY_GROUP[index]['dp']:
+            _PARALLEL_STRATEGY_GROUP[index]['dp_group'].append(torch.distributed.new_group(
+                ranks, timeout=timeout, pg_options=global_mpu.get_nccl_options('dp', {})
+            ))
+            _PARALLEL_STRATEGY_GROUP[index]['dp_group_gloo'].append(torch.distributed.new_group(ranks, timeout=timeout, backend="gloo"))
+
+        _PARALLEL_STRATEGY_GROUP[index]['cp_group'] = []
+        _PARALLEL_STRATEGY_GROUP[index]['cp'] = rank_generator.get_ranks('cp')
+        for ranks in _PARALLEL_STRATEGY_GROUP[index]['cp']:
+            _PARALLEL_STRATEGY_GROUP[index]['cp_group'].append(torch.distributed.new_group(
+                ranks, timeout=timeout, pg_options=global_mpu.get_nccl_options('cp', {})
+            ))
+
+
+def modify_parallel(strategy_idx):
+    rank = torch.distributed.get_rank()
+    parallel_strategy = _PARALLEL_STRATEGY_GROUP[int(strategy_idx)]
+    for ranks, group, group_gloo in zip(parallel_strategy['dp'], parallel_strategy['dp_group'], parallel_strategy['dp_group_gloo']):
+        if rank in ranks:
+            global_mpu._DATA_PARALLEL_GROUP = group
+            global_mpu._DATA_PARALLEL_GROUP_GLOO = group_gloo
+            global_mpu._DATA_PARALLEL_GLOBAL_RANKS = ranks
+    for ranks, group in zip(parallel_strategy['cp'], parallel_strategy['cp_group']):
+        if rank in ranks:
+            global_mpu._CONTEXT_PARALLEL_GROUP = group
+            global_mpu._CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
+    torch.distributed.barrier()
+
+
+def dynamic_dpcp_transfer_data(batch):
+    """
+    Broadcast the batch that holds the longest sequence within the CP group.
+    """
+    context_group = global_mpu.get_context_parallel_group()
+    rank = torch.distributed.get_rank()
+    device = batch[VIDEO].device
+
+    local_video_length = batch[VIDEO].shape.numel()
+    local_length_tensor = torch.tensor([local_video_length], dtype=torch.long, device=device)
+
+    group_size = torch.distributed.get_world_size(group=context_group)
+    all_lengths = [torch.zeros_like(local_length_tensor) for _ in range(group_size)]
+    torch.distributed.all_gather(all_lengths, local_length_tensor, group=context_group)
+
+    video_lengths = [length.item() for length in all_lengths]
+    max_video_length = max(video_lengths)
+    src_rank_local = video_lengths.index(max_video_length)
+    src_rank = torch.distributed.get_global_rank(context_group, src_rank_local)
+
+    if rank != src_rank:
+        CACHED_BATCH.append(deep_copy_batch(batch))
+
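+    # Per-key broadcast from src_rank (the rank holding the longest sequence): first the tensor
+    # rank, then its shape, then the tensor data; non-tensor entries go via broadcast_object_list.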
+    for key, value in batch.items():
+        if value is not None and isinstance(value, torch.Tensor):
+            shape = torch.tensor(list(value.shape), dtype=torch.long).to(device)
+            shape_size = torch.tensor([len(shape)], dtype=torch.long).to(device)
+            torch.distributed.broadcast(shape_size, src=src_rank, group=context_group)
+            if rank == src_rank:
+                torch.distributed.broadcast(shape, src=src_rank, group=context_group)
+            else:
+                shape = torch.zeros(shape_size.item(), dtype=torch.long).to(device)
+                torch.distributed.broadcast(shape, src=src_rank, group=context_group)
+            if rank == src_rank:
+                torch.distributed.broadcast(value.contiguous(), src=src_rank, group=context_group)
+            else:
+                value = torch.zeros(size=[dim.item() for dim in shape], dtype=batch[key].dtype).to(device)
+                torch.distributed.broadcast(value.contiguous(), src=src_rank, group=context_group)
+            batch[key] = value
+        else:
+            if rank == src_rank:
+                message = [value]
+            else:
+                message = [None]
+            torch.distributed.broadcast_object_list(message, src=src_rank, group=context_group)
+            batch[key] = message[0]
+    torch.distributed.barrier(group=context_group)
+
+
+def get_optimized_parallel_strategy(input_seq):
+    args = get_args()
+    dp_group = global_mpu.get_data_parallel_group()
+    dp_size = global_mpu.get_data_parallel_world_size()
+    local_rank = torch.distributed.get_rank()
+    local_info = {
+        key: {
+            'shape': tensor.shape
+        } for key, tensor in input_seq.items()
+    }
+    gather_rank_dst = torch.distributed.get_global_rank(dp_group, 0)
+    if local_rank == gather_rank_dst:
+        all_info = [None] * dp_size
+        torch.distributed.gather_object(local_info, all_info, dst=gather_rank_dst)
+    else:
+        torch.distributed.gather_object(local_info, None, dst=gather_rank_dst)
+    strategy_idx = torch.tensor(0).to('npu')
+    if local_rank == gather_rank_dst:
+        seq_size_list = [info[VIDEO]['shape'].numel() for info in all_info]
+        args = get_args()
+        oom_flag = True
+        for idx, strategy in enumerate(_PARALLEL_STRATEGY_LIST):
+            dp, cp = strategy
+            if max(seq_size_list) / cp <= args.mm.model.max_seq_size:
+                oom_flag = False
+                strategy_idx = torch.tensor(idx).to('npu')
+                break
+        if oom_flag is True:
+            print("max_seq_size is too small: no candidate DP/CP strategy can fit the longest sequence")
+
+    torch.distributed.broadcast(strategy_idx, src=gather_rank_dst)
+    return strategy_idx
+
+
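+# Entry point called from forward_step in pretrain_sora.py: re-selects the [dp, cp] layout for the
+# current iteration based on the gathered batch shapes and redistributes the data accordingly.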
+def data_aware_parallel_optimize(batch):
+    args = get_args()
+    if hasattr(args.mm.model, "use_dynamic_dpcp") and args.mm.model.use_dynamic_dpcp:
+        # 0. Extract the sequence the decision is based on
+        input_seq = {VIDEO: batch[VIDEO]}
+        # 1. Reset to the pure-DP layout
+        modify_parallel(0)
+        # 2. Run the optimization: output the new parallel strategy, which also determines how data is redistributed
+        strategy_idx = get_optimized_parallel_strategy(input_seq)
+        if strategy_idx > 0:
+            print_rank_0("adjust [dp,cp] to {}".format(_PARALLEL_STRATEGY_LIST[strategy_idx]))
+        # 3. Switch to the new parallel strategy
+        modify_parallel(strategy_idx)
+        # 4. Redistribute the data
+        dynamic_dpcp_transfer_data(batch)
\ No newline at end of file
diff --git a/pretrain_sora.py b/pretrain_sora.py
index 5df06c2cdd83cf1036860aafefe250d58ce1f758..bf49bb60e169124bd6bdbac51d8fed480760b202 100644
--- a/pretrain_sora.py
+++ b/pretrain_sora.py
@@ -27,6 +27,7 @@ from mindspeed_mm.data.data_utils.constants import (
 from mindspeed_mm.data.data_utils.utils import build_iterations
 from mindspeed_mm.models.sora_model import SoRAModel
 from mindspeed_mm.patchs import dummy_optimizer_patch
+import mindspeed_mm.utils.dpcp_utils as dpcp
 
 mindspeed_args = get_mindspeed_args()
 if hasattr(mindspeed_args, "ai_framework") and mindspeed_args.ai_framework == "mindspore" and mindspeed_args.optimization_level >= 0:
@@ -101,12 +102,25 @@ def get_batch_for_step(data_iterator):
     return batches, get_batch_interval
 
 
+def get_dpcp_batch_for_step(data_iterator):
+    if dpcp.CACHED_BATCH:
+        batch = dpcp.CACHED_BATCH.pop(0)
+    else:
+        batch = get_batch(data_iterator)
+    return [batch]
+
+
 def forward_step(data_iterator, model):
     """Forward step."""
     batch, video, prompt_ids, video_mask, prompt_mask = {}, None, None, None, None
     skip_encode = False
+    args = get_args()
     if mpu.is_pipeline_first_stage():
-        batches, get_batch_interval = get_batch_for_step(data_iterator)
+        if hasattr(args.mm.model, "use_dynamic_dpcp") and args.mm.model.use_dynamic_dpcp:
+            batches = get_dpcp_batch_for_step(data_iterator)
+            get_batch_interval = 0
+        else:
+            batches, get_batch_interval = get_batch_for_step(data_iterator)
         skip_encode = not batches
     i2v_params = defaultdict(list)
     # while encoder dp or encoder interleave offload is enabled. reconstruct data as list: [step_1, ... step_n]
@@ -126,6 +140,7 @@ def forward_step(data_iterator, model):
         batch = i2v_params
     elif len(batches) == 1:
         batch = batches[0]
+    dpcp.data_aware_parallel_optimize(batch)
     video = batch.pop(VIDEO, None)
     prompt_ids = batch.pop(PROMPT_IDS, None)
     video_mask = batch.pop(VIDEO_MASK, None)