diff --git a/README.md b/README.md
index 3d4ff246413854eeb3e248f64b47e75646413def..0ebdfa219ceee731172010473068b6f2fc395749 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,16 @@ MindSpeed RL是基于昇腾生态的强化学习加速框架,旨在为华为 [
      Preview
+
+      长序列并行
+      Doc
+      GRPO
+
+      Qwen2.5-7B<br>
+ Qwen2.5-32B
+
+      Preview
+
diff --git a/docs/features/context_parallel.md b/docs/features/context_parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..d8816c694fd0fabbad3442a6405f32ac5126b684
--- /dev/null
+++ b/docs/features/context_parallel.md
@@ -0,0 +1,27 @@
+# Context Parallelism
+
+## Background
+The demand for long-sequence training keeps growing, and its applications are broad, e.g. translation and multimodal scenarios. To address the out-of-memory problems caused by long sequences, this repository provides a context parallelism (Context Parallel) solution.
+
+## Approach
+### Ulysses
+[Ulysses](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ulysses) is a distributed parallelism technique for long-sequence training proposed by Microsoft DeepSpeed. Its core idea is to split the input sequence along the sequence dimension across compute devices and to use All-to-All communication so that each device computes a different subset of the attention heads. This reduces activation memory and avoids OOM in long-sequence scenarios.
+
+Concretely, Ulysses splits each sample along the sequence dimension across the participating devices. Before the attention computation, it performs an all-to-all on the partitioned query (Q), key (K), and value (V) tensors so that every device receives the full sequence, but only for a non-overlapping subset of the attention heads; the devices can therefore compute different attention heads in parallel. Finally, Ulysses uses another all-to-all to gather the results across attention heads while re-partitioning along the sequence dimension.
+
+
+## Usage
+
+Context Parallel in this repository currently supports Ulysses-style partitioning and is enabled with the following configuration:
+```
+actor_config:
+  context_parallel_size: 2
+  context_parallel_algo: ulysses_cp_algo
+```
+
+Where:
+
+`context_parallel_size` is the context-parallel (CP) degree. When ulysses_cp_algo is used, the model must satisfy **num_attention_heads % (CP * TP) == 0**.
+
+`context_parallel_algo` selects the context parallelism algorithm. Currently only **ulysses_cp_algo** is supported; if this parameter is not set, it defaults to **ulysses_cp_algo**.
+
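+### Illustration
+
+The following single-process sketch (an editor's illustration under assumed shapes and `cp_size=2`, not the framework implementation) shows the layout change performed by the Ulysses all-to-all: each context-parallel rank starts with a sequence shard of all attention heads and ends up with the full sequence for a subset of heads. In real training the exchange is done with `torch.distributed` all-to-all over the context-parallel group.
+
+```python
+import torch
+
+batch, seq_len, num_heads, head_dim, cp_size = 2, 8, 4, 16, 2
+
+# Before the all-to-all, each CP rank holds [batch, seq_len / cp, num_heads, head_dim].
+full_q = torch.randn(batch, seq_len, num_heads, head_dim)
+seq_shards = list(full_q.chunk(cp_size, dim=1))
+
+# Simulated all-to-all: trade sequence chunks for head chunks, so every rank
+# ends up with [batch, seq_len, num_heads / cp, head_dim].
+head_shards = [
+    torch.cat([shard.chunk(cp_size, dim=2)[rank] for shard in seq_shards], dim=1)
+    for rank in range(cp_size)
+]
+assert head_shards[0].shape == (batch, seq_len, num_heads // cp_size, head_dim)
+```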
diff --git a/mindspeed_rl/config_cls/megatron_config.py b/mindspeed_rl/config_cls/megatron_config.py
index f27e7c3bad563ba4a32e769693edff499408a750..38a47134d87cedc531ed84c01fc7a89842e15f71 100644
--- a/mindspeed_rl/config_cls/megatron_config.py
+++ b/mindspeed_rl/config_cls/megatron_config.py
@@ -138,6 +138,7 @@ class MegatronConfig(BaseConfig):
        dataloader_type: Single pass vs multiple pass data loader (default: None)
        enable_high_availability: Switch of the high availability feature (default: False)
        context_parallel_size: Degree of context parallelism (default: 1)
+        context_parallel_algo: Algorithm of context parallelism (default: ulysses_cp_algo)
        reset_position_ids: Reset posistion ids after end-of-document token (default: False)
        optimizer: Optimizer function (default: 'adam')
        do_sample: Enable doing sample in actor generations (default: False)
@@ -326,6 +327,7 @@
        self.dataloader_type = None
        self.enable_high_availability = False
        self.context_parallel_size = 1
+        self.context_parallel_algo = "ulysses_cp_algo"
        self.reset_position_ids = False
        self.optimizer = 'adam'
        self.do_sample = False
diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py
index 70a04f35d34801a2127fdbd161661f730f59dc19..56c3f1c5d349a63bcda72be03515114fdeee0b12 100644
--- a/mindspeed_rl/config_cls/validate_config.py
+++ b/mindspeed_rl/config_cls/validate_config.py
@@ -49,6 +49,10 @@ def validate_rl_args(
                             f"Actor.seq_length={actor_config.seq_length} vs "
                             f"GenerateConfig.max_model_len={generate_config.max_model_len}")
 
+    if actor_config.context_parallel_size > 1 and actor_config.context_parallel_algo is not None:
+        if actor_config.context_parallel_algo not in ["ulysses_cp_algo"]:
+            raise ValueError("Currently only ulysses_cp_algo is supported for context parallelism.")
+
    # 校验移除填充特性相关配置
    if rl_config.use_remove_padding:
        if actor_config.pipeline_model_parallel_size > 1 and not actor_config.variable_seq_lengths:
diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py
index 00c58448c366dd348e923d6aa2551cadfb3cede5..27308d7497ae6a2f6353ae0c44e7e75e5802fccc 100644
--- a/mindspeed_rl/models/base/base_training_engine.py
+++ b/mindspeed_rl/models/base/base_training_engine.py
@@ -14,6 +14,7 @@ from mindspeed_rl.utils.utils import (
 )
 from mindspeed_rl.utils.remove_padding import preprocess_packed_seqs, postprocess_packed_seqs
 from mindspeed_rl.utils.compute import get_parallel_state
+from mindspeed_rl.utils.utils import get_batch_on_this_cp_rank
 
 
 class BaseTrainingEngine(ABC):
@@ -56,6 +57,8 @@ class BaseTrainingEngine(ABC):
            set_actual_seq_len: Callable = None,
            forward_backward_func: Callable = None,
            entropy_coeff: float = 0.0,
+            context_parallel_algo: str = "ulysses_cp_algo",
+            context_parallel_size: int = 1,
            kl_penalty: str = "low_var_kl",
            **kwargs):
        self.forward_backward_func = forward_backward_func
@@ -75,6 +78,8 @@
        self.kl_penalty = kl_penalty
        self.clip_ratio = clip_ratio
        self.entropy_coeff = entropy_coeff
+        self.context_parallel_algo = context_parallel_algo
+        self.context_parallel_size = context_parallel_size
        self.temperature = temperature
        self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role)
        self.kwargs = kwargs
@@ -102,8 +107,9 @@
        data_iter = [iter(batches) for _ in self.model]
        self.loss_func.add_loss_meta_info(self.get_loss_meta_func())
        post_process = get_parallel_state().get_pipeline_model_parallel_world_size() == 1 or get_parallel_state().is_pipeline_last_stage()
-
+
        def forward_step(batch_iter, model):
+            cp_size = get_parallel_state().get_context_parallel_world_size()
            if self.use_remove_padding:
                input_ids, position_ids, process_batch, seqlens_in_batch, cu_seqlens_padded = self._get_forward_batch_info(batch_iter)
                self.set_actual_seq_len(cu_seqlens_padded.tolist())
@@ -118,6 +124,12 @@
            else:
                input_ids, attention_mask, position_ids, process_batch = self._get_forward_batch_info(batch_iter)
                output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
+            if post_process:
+                if cp_size > 1:
+                    output_list = [torch.empty_like(output) for _ in range(cp_size)]
+                    torch.distributed.all_gather(output_list, output.detach(), group=get_parallel_state().get_context_parallel_group())
+                    output_list[get_parallel_state().get_context_parallel_rank()] = output
+                    output = torch.cat(output_list, dim=1)
            output.div_(self.temperature)
            return output, partial(self.loss_func.compute_loss, batch=process_batch, forward_only=forward_only)
@@ -146,6 +158,7 @@
        input_ids = batch['input_ids']
        attention_mask_1d = generate_mask(input_ids, batch['prompt_length'] + batch['response_length']).to(
            input_ids.device)
+        cp_size = get_parallel_state().get_context_parallel_world_size()
        if self.use_remove_padding:
            tp_size = get_parallel_state().get_tensor_model_parallel_world_size()
            input_ids, position_ids, seqlens_in_batch, cu_seqlens_padded = preprocess_packed_seqs(
@@ -154,7 +167,18 @@
        else:
            position_ids = torch.tensor(generate_position_ids(input_ids)).to(input_ids.device)
            attention_mask = get_tune_attention_mask(attention_mask_1d)
-        return input_ids, attention_mask, position_ids, batch
+        if cp_size > 1:
+            batch_for_cp = {
+                'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'position_ids': position_ids
+            }
+            batch_cp = get_batch_on_this_cp_rank(self.context_parallel_algo, self.context_parallel_size, batch_for_cp)
+            input_ids = batch_cp['input_ids']
+            attention_mask = None
+            position_ids = batch_cp['position_ids']
+
+        return input_ids, attention_mask, position_ids, batch
 
    def post_process_forward_backward_output(self, output: [torch.Tensor],
                                             batch: Dict[str, torch.Tensor]) -> torch.Tensor:
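For reference, a minimal standalone sketch (not part of the patch; the single-rank `gloo` setup and the port are illustrative assumptions) of the gather pattern used in `forward_step` above when context parallelism is enabled: the detached log-prob shards are all-gathered along the sequence dimension, and the local shard is re-inserted so that gradients still flow through this rank's own computation.

```python
import torch
import torch.distributed as dist

# Single-process process group so the example runs standalone; the port is arbitrary.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29511",
                        rank=0, world_size=1)
cp_group = dist.new_group(ranks=[0])
cp_size = dist.get_world_size(cp_group)

local_output = torch.randn(2, 8, requires_grad=True)        # [batch, local_seq_len]
gathered = [torch.empty_like(local_output) for _ in range(cp_size)]
dist.all_gather(gathered, local_output.detach(), group=cp_group)

# all_gather returns detached tensors; substituting the local shard keeps autograd intact.
gathered[dist.get_rank(cp_group)] = local_output
full_output = torch.cat(gathered, dim=1)                     # [batch, full_seq_len]
full_output.sum().backward()
assert local_output.grad is not None
dist.destroy_process_group()
```

The patch follows the same idea: `output.detach()` is what gets gathered, and the rank's own entry in the gathered list is replaced with the original `output` before concatenation.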
diff --git a/mindspeed_rl/trainer/utils/parallel_state.py b/mindspeed_rl/trainer/utils/parallel_state.py
index 788e287f4dee50ed012576ccd648a4587703dde5..5e47fcfb9614df678257d0538a186d7dff67903d 100644
--- a/mindspeed_rl/trainer/utils/parallel_state.py
+++ b/mindspeed_rl/trainer/utils/parallel_state.py
@@ -41,6 +41,17 @@ def get_tensor_model_parallel_rank(mpu, use_vllm=False):
        return mpu.get_tensor_model_parallel_rank()
 
 
+def get_context_parallel_rank(mpu, use_vllm=False):
+    if use_vllm:
+        from vllm.distributed import parallel_state as vpu
+        if not hasattr(vpu, "get_context_parallel_rank"):
+            vpu = mpu
+        return vpu.get_context_parallel_rank()
+    else:
+        return mpu.get_context_parallel_rank()
+
+
+
 def get_tensor_model_parallel_src_rank(mpu, use_vllm=False):
    if use_vllm:
        from vllm.distributed import parallel_state as vpu
@@ -49,6 +60,13 @@ def get_tensor_model_parallel_src_rank(mpu, use_vllm=False):
        return mpu.get_tensor_model_parallel_src_rank()
 
 
+def get_context_parallel_src_rank(mpu, use_vllm=False):
+    if use_vllm:
+        raise NotImplementedError("not implemented yet.")
+    else:
+        return mpu.get_context_parallel_global_ranks()[0]
+
+
 def get_tensor_model_parallel_group(mpu, use_vllm=False):
    if use_vllm:
        from vllm.distributed import parallel_state as vpu
@@ -57,10 +75,35 @@ def get_tensor_model_parallel_group(mpu, use_vllm=False):
        return mpu.get_tensor_model_parallel_group()
 
 
+def get_context_parallel_group(mpu, use_vllm=False):
+    if use_vllm:
+        raise NotImplementedError("not implemented yet.")
+    else:
+        return mpu.get_context_parallel_group()
+
+
 def get_model_parallel_group(mpu, use_vllm=False):
    if use_vllm:
        import vllm
        from vllm.distributed import parallel_state as vpu
        return vpu.get_tensor_model_parallel_group().device_group
    else:
-        return mpu.get_model_parallel_group()
\ No newline at end of file
+        return mpu.get_model_parallel_group()
+
+
+def get_tensor_and_context_parallel_rank(mpu, use_vllm=False):
+    """Return caller's rank in the joint tensor-model-parallel and context-parallel group."""
+    if use_vllm:
+        from vllm.distributed import parallel_state as vpu
+        return vpu.get_tensor_model_parallel_rank()
+    else:
+        return mpu.get_tensor_and_context_parallel_rank()
+
+
+def get_tensor_and_context_parallel_group(mpu, use_vllm=False):
+    """Get the tensor- and context-parallel group the caller rank belongs to."""
+    if use_vllm:
+        from vllm.distributed import parallel_state as vpu
+        return vpu.get_tensor_model_parallel_group().device_group
+    else:
+        return mpu.get_tensor_and_context_parallel_group()
\ No newline at end of file
diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py
index 44d5a631aab266aee2432555c2dab056ff456074..2358dd0b6e312cd63871b694290fb20dd7f7055d 100644
--- a/mindspeed_rl/utils/utils.py
+++ b/mindspeed_rl/utils/utils.py
@@ -434,6 +434,23 @@ def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False)
    return getattr(model, attr)
 
 
+def get_batch_on_this_cp_rank(context_parallel_algo, context_parallel_size, batch):
+    """ Slice batch input along sequence dimension into multiple chunks,
+    which are parallelized across GPUs in a context parallel group.
+    """
+    from mindspeed.utils import (set_actual_seq_len, set_position_ids,
+                                 _get_batch_on_this_cp_rank_in_ulysses_cp)
+
+    if context_parallel_size <= 1:
+        return batch
+
+    if context_parallel_algo == 'ulysses_cp_algo':
+        batch = _get_batch_on_this_cp_rank_in_ulysses_cp(batch)
+    else:
+        raise NotImplementedError("Only ulysses_cp_algo is supported for context parallelism.")
+    return batch
+
+
 def get_grpo_profiler(profiler_config, role: str = None):
    args = profiler_config
    if not args or not args.profile:
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 69e5cd17f97cbe71d628a1d5b4403b5022bf75ab..1acdd54a6388bb658e3d21bac1ab6663c83979d2 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -111,6 +111,8 @@ class ActorHybridWorkerBase(BaseWorker):
            micro_batch_size=self.megatron_config.micro_batch_size,
            use_remove_padding=self.rl_config.use_remove_padding,
            set_actual_seq_len=megatron_module['set_actual_seq_len'],
+            context_parallel_algo=self.megatron_config.context_parallel_algo,
+            context_parallel_size=self.megatron_config.context_parallel_size,
            entropy_coeff=self.rl_config.entropy_coeff,
            kl_penalty=self.rl_config.kl_penalty,
            temperature=self.generate_config.sampling_config["temperature"]
@@ -172,6 +174,8 @@
                experience_columns,
                experience_count,
                self.megatron_config.tensor_model_parallel_size,
+                self.megatron_config.context_parallel_size,
+                self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(
                    0) if self.rl_config.guarantee_order else None,
                get_n_samples=False)
@@ -184,7 +188,7 @@
            self.args.consumed_train_samples += self.megatron_config.global_batch_size // self.rl_config.n_samples_per_prompt
            self.num_floating_point_operations_so_far += num_floating_point_operations(self.args,
                                                                                       self.megatron_config.global_batch_size)
-            if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.parallel_state.get_tensor_model_parallel_rank() == 0:
+            if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.parallel_state.get_tensor_model_parallel_rank() == 0 and self.parallel_state.get_context_parallel_rank() == 0:
                ray.get(self.td.update_metrics.remote(value=metrics, cumulate=True))
                ray.get(
                    self.td.update_metrics.remote(
@@ -243,6 +247,8 @@
                experience_columns,
                experience_count,
                tp_size=self.megatron_config.tensor_model_parallel_size,
+                cp_size=self.megatron_config.context_parallel_size,
+                cp_algo=self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
                use_vllm=True
            )
@@ -322,6 +328,8 @@
                experience_columns,
                experience_count,
                tp_size=self.megatron_config.tensor_model_parallel_size,
+                cp_size=self.megatron_config.context_parallel_size,
+                cp_algo=self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(
                    0) if self.rl_config.guarantee_order else None,
                get_n_samples=False)
diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py
index 51d91a3624a4c8da89cb922973f033d48c57a411..d5aa486f62fc90d992a1c18986b6b407028e0a28 100644
--- a/mindspeed_rl/workers/base_worker.py
+++ b/mindspeed_rl/workers/base_worker.py
@@ -10,6 +10,7 @@ import socket
 import torch
 import torch_npu
 import ray
+import torch.distributed as dist
 
 from mindspeed_rl.models.rollout.vllm_adapter.vllm_parallel_state import get_vllm_tp_group_ranks
 from mindspeed_rl.utils.loggers import Loggers
@@ -26,8 +27,11 @@ from mindspeed_rl.trainer.utils.parallel_state import (
    is_pipeline_last_stage,
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
+    get_context_parallel_rank,
    get_tensor_model_parallel_src_rank,
-    get_model_parallel_group
+    get_model_parallel_group,
+    get_context_parallel_src_rank,
+    get_context_parallel_group
 )
 from mindspeed_rl.utils.compute import set_parallel_state, set_vocab_parallel
 from mindspeed_rl.utils.utils import get_current_dp_range_indexes
@@ -168,12 +172,24 @@ class BaseWorker(BaseRayWorker, ABC):
        else:
            current_device = next(self.model[0].parameters()).device
            status = torch.tensor(0, device=current_device)
-        if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
-                get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+
+        rank_flg = False
+        if not use_vllm:
+            rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+                        get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and
+                        get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+        else:
+            rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+                        get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+        if rank_flg:
            status = torch.tensor(int(not ray.get(self.td.all_consumed.remote(experience_consumer_stage))),
                                  device=current_device)
 
        torch.distributed.all_reduce(status, group=get_model_parallel_group(self.parallel_state, use_vllm),
                                     op=torch.distributed.ReduceOp.MAX)
+        if not use_vllm:
+            torch.distributed.all_reduce(status, group=get_context_parallel_group(self.parallel_state, use_vllm),
+                                         op=torch.distributed.ReduceOp.MAX)
+
        return status
 
    def setup_distributed_rank(self):
@@ -223,17 +239,26 @@ class BaseWorker(BaseRayWorker, ABC):
    @mstx_timer_decorator
    def dispatch_transfer_dock_data(self, experience_consumer_stage,
-                                    experience_columns, experience_count, tp_size=1,
+                                    experience_columns, experience_count, tp_size=1, cp_size=1, cp_algo=None,
                                    use_vllm=False, indexes=None, get_n_samples=True):
        pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
        batch_data = {}
        batch_data_length = {}
-        # make sure that all ranks in tp/pp group enter dispatch_transfer_dock_data,
+        # make sure that all ranks in cp/tp/pp group enter dispatch_transfer_dock_data,
        # in case of rank0 get_experience before other ranks judge td.all_consumed
-        if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
-                get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+
+        rank_flg = False
+        if not use_vllm:
+            rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+                        get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and
+                        get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+        else:
+            rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+                        get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+
+        if rank_flg:
            batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns,
                                                                      experience_count, indexes=indexes,
                                                                      get_n_samples=get_n_samples))  # cpu数据
@@ -244,11 +269,17 @@
        else:
            index = torch.empty(experience_count, device=torch.cuda.current_device(), dtype=torch.int64)
-        # # 传输index, 并判断是否取出了数据
        torch.distributed.broadcast(
            index, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
            group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
        )
+
+        if not use_vllm:
+            torch.distributed.broadcast(
+                index, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+                group=get_context_parallel_group(self.parallel_state, use_vllm)
+            )
+
        torch.distributed.broadcast(
            index, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
            group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
        )
@@ -257,13 +288,11 @@
 
        if index[0].item() == -1:
            return None, None
-        if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
-                get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+        if rank_flg:
            batch_data, batch_data_length = pack_experience_columns(batch_data, experience_count)
 
        for key in experience_columns:
-            if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
-                    get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+            if rank_flg:
                batch_data_shape = torch.tensor(batch_data[key].shape,
                                                dtype=torch.int64, device=torch.cuda.current_device())
@@ -280,7 +309,7 @@
            batch_data_dtype = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64)
            batch_data_length_shape = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64)
 
-            # 传输tensor数据形状和类型
+            # TP domain sync
            torch.distributed.broadcast(
                batch_data_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
                group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
            )
            torch.distributed.broadcast(
@@ -289,6 +318,23 @@
                batch_data_dtype, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
                group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
            )
+            torch.distributed.broadcast(batch_data_length_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
+                                        group=get_tensor_model_parallel_group(self.parallel_state, use_vllm))
+
+            # CP domain sync
+            if not use_vllm:
+                torch.distributed.broadcast(
+                    batch_data_shape, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+                    group=get_context_parallel_group(self.parallel_state, use_vllm)
+                )
+                torch.distributed.broadcast(
+                    batch_data_dtype, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+                    group=get_context_parallel_group(self.parallel_state, use_vllm)
+                )
+                torch.distributed.broadcast(batch_data_length_shape, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+                                            group=get_context_parallel_group(self.parallel_state, use_vllm))
+
+            # PP domain sync
            torch.distributed.broadcast(
                batch_data_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
                group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
            )
@@ -297,14 +343,11 @@
                batch_data_dtype, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
                group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
            )
-
-            torch.distributed.broadcast(batch_data_length_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
-                                        group=get_tensor_model_parallel_group(self.parallel_state, use_vllm))
            torch.distributed.broadcast(batch_data_length_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
                                        group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm))
-            if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) != 0 or \
-                    get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) != 0:
+
+            if not rank_flg:
                if batch_data_dtype == 1:
                    batch_data[key] = torch.empty(batch_data_shape[0],
                                                  # batch_data_shape[1],
                                                  device=torch.cuda.current_device(),
                                                  dtype=torch.int64)
@@ -320,6 +363,13 @@
                batch_data[key].cuda(), get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
                group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
            )
+
+            if not use_vllm:
+                torch.distributed.broadcast(
+                    batch_data[key].cuda(), get_context_parallel_src_rank(self.parallel_state, use_vllm),
+                    group=get_context_parallel_group(self.parallel_state, use_vllm)
+                )
+
            torch.distributed.broadcast(
                batch_data[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
                group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
@@ -328,13 +378,19 @@
            torch.distributed.broadcast(batch_data_length[key].cuda(), get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
                                        group=get_tensor_model_parallel_group(self.parallel_state, use_vllm))
+            if not use_vllm:
+                torch.distributed.broadcast(batch_data_length[key].cuda(), get_context_parallel_src_rank(self.parallel_state, use_vllm),
+                                            group=get_context_parallel_group(self.parallel_state, use_vllm))
+
            torch.distributed.broadcast(batch_data_length[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
                                        group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm))
 
        index_without_pad = index.cpu().numpy().tolist()[:batch_data_shape[0]]
 
        if batch_data:
-            padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, tp_size)
+            pad_multiple = tp_size * cp_size if cp_algo == 'ulysses_cp_algo' else tp_size
+            padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, pad_multiple)
+
            return padded_batch_data, index_without_pad
        else:
            return {}, []
diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py
index a7035ffd12c8c79be3e20d23986e2e658dabb0a6..a0ae627bc7d24778b3312b8f905a54452988619b 100644
--- a/mindspeed_rl/workers/integrated_worker.py
+++ b/mindspeed_rl/workers/integrated_worker.py
@@ -103,6 +103,8 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB
            stage=self.megatron_config.stage,
            forward_backward_func=self.forward_backward_func,
            micro_batch_size=self.megatron_config.micro_batch_size,
+            context_parallel_algo=self.megatron_config.context_parallel_algo,
+            context_parallel_size=self.megatron_config.context_parallel_size,
            use_remove_padding=self.rl_config.use_remove_padding,
            set_actual_seq_len=megatron_module['set_actual_seq_len'],
            temperature=self.generate_config.sampling_config["temperature"]
diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py
index 04b87738ddc55cc34bac8b3bb9e826d6f00a81b3..fc5423777854caed3174a14870ad398178c8d6ff 100644
--- a/mindspeed_rl/workers/reference_woker.py
+++ b/mindspeed_rl/workers/reference_woker.py
@@ -15,9 +15,12 @@ from mindspeed_rl.utils.pad_process import truncate_rows
 from mindspeed_rl.utils.tokenizer import BaseTokenizer
 from mindspeed_rl.workers.base_worker import BaseWorker
 from mindspeed_rl.utils.compute import get_parallel_state
-from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank
+from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank, get_context_parallel_rank
+from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.utils.utils import mstx_timer_decorator
 
+logger = Loggers(__name__)
+
 
 class ReferenceWorkerBase(BaseWorker):
    """
@@ -80,6 +83,8 @@ class ReferenceWorkerBase(BaseWorker):
            stage=self.megatron_config.stage,
            forward_backward_func=self.forward_backward_func,
            micro_batch_size=self.megatron_config.micro_batch_size,
+            context_parallel_algo=self.megatron_config.context_parallel_algo,
+            context_parallel_size=self.megatron_config.context_parallel_size,
            use_remove_padding=self.rl_config.use_remove_padding,
            set_actual_seq_len=megatron_module['set_actual_seq_len'],
            temperature=self.generate_config.sampling_config["temperature"]
@@ -102,10 +107,12 @@ class ReferenceWorkerBase(BaseWorker):
                experience_columns,
                experience_count,
                tp_size=self.megatron_config.tensor_model_parallel_size,
+                cp_size=self.megatron_config.context_parallel_size,
+                cp_algo=self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(
                    0) if self.rl_config.guarantee_order else None,
                get_n_samples=False)
-
+
            if not start_time_defined:
                start_time = time.time()
                start_time_defined = True
@@ -137,7 +144,7 @@
        parallel_state = get_parallel_state()
        use_vllm = False
-        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0:
+        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0 and self.parallel_state.get_context_parallel_rank() == 0:
            ref_end_time = time.time()
            ray.get(
                self.td.update_metrics.remote(
diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py
index 182b695f118e1345ad77adc9345017b46abd340e..eebfb95df5eb5aa66f1389220f06ffce115be58e 100644
--- a/mindspeed_rl/workers/reward_woker.py
+++ b/mindspeed_rl/workers/reward_woker.py
@@ -14,7 +14,7 @@ from mindspeed_rl.trainer.utils.compute_utils import get_last_reward
 from mindspeed_rl.utils.tokenizer import BaseTokenizer
 from mindspeed_rl.workers.base_worker import BaseWorker
 from mindspeed_rl.utils.compute import get_parallel_state
-from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank
+from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank, get_context_parallel_rank
 
 
 class RewardWorkerBase(BaseWorker):
@@ -74,6 +74,8 @@
            stage=self.megatron_config.stage,
            forward_backward_func=self.forward_backward_func,
            micro_batch_size=self.megatron_config.micro_batch_size,
+            context_parallel_algo=self.megatron_config.context_parallel_algo,
+            context_parallel_size=self.megatron_config.context_parallel_size,
            use_remove_padding=self.rl_config.use_remove_padding,
            set_actual_seq_len=megatron_module['set_actual_seq_len'],
            temperature=self.generate_config.sampling_config["temperature"]
@@ -104,6 +106,8 @@
                experience_columns,
                experience_count,
                tp_size=self.megatron_config.tensor_model_parallel_size,
+                cp_size=self.megatron_config.context_parallel_size,
+                cp_algo=self.megatron_config.context_parallel_algo,
                indexes=sorted_indexes.pop(
                    0) if self.rl_config.guarantee_order else None,
            )
@@ -139,7 +143,7 @@
        )
        parallel_state = get_parallel_state()
        use_vllm = False
-        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0:
+        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0 and self.parallel_state.get_context_parallel_rank() == 0:
            rwd_end_time = time.time()
            ray.get(
                self.td.update_metrics.remote(
diff --git a/tests/ut/workers/test_base_worker.py b/tests/ut/workers/test_base_worker.py
index f6bc36e976664f9d9beae3cfdaae9e5fa28f01dd..28aca380450f06b253b7863471a44441fe7a1922 100644
--- a/tests/ut/workers/test_base_worker.py
+++ b/tests/ut/workers/test_base_worker.py
@@ -148,4 +148,4 @@ class TestBaseWorker(DistributedTest):
        _, _ = worker.dispatch_transfer_dock_data(experience_consumer_stage,
                                                  experience_columns,
                                                  experience_count)
-        assert mock_broadcast.call_count == 2
+        assert mock_broadcast.call_count == 3