diff --git a/README.md b/README.md
index 3d4ff246413854eeb3e248f64b47e75646413def..0ebdfa219ceee731172010473068b6f2fc395749 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,16 @@ MindSpeed RL是基于昇腾生态的强化学习加速框架,旨在为华为 [
Preview |
+
+ Context Parallelism |
+ Doc |
+ GRPO |
+
+ Qwen2.5-7B
+ Qwen2.5-32B
+ |
+ Preview |
+
diff --git a/docs/features/context_parallel.md b/docs/features/context_parallel.md
new file mode 100644
index 0000000000000000000000000000000000000000..d8816c694fd0fabbad3442a6405f32ac5126b684
--- /dev/null
+++ b/docs/features/context_parallel.md
@@ -0,0 +1,27 @@
+# Context Parallelism
+
+## Background
+The demand for long-sequence training keeps growing, and its applications are broad, ranging from translation to multimodal scenarios. To address the out-of-memory (OOM) problem caused by long sequences, this repository provides a Context Parallel (CP) solution.
+
+## Approach
+### Ulysses
+[Ulysses](https://github.com/deepspeedai/DeepSpeed/tree/master/blogs/deepspeed-ulysses) is a distributed parallelism technique for long-sequence training proposed by Microsoft DeepSpeed. Its core idea is to split the input sequence along the sequence dimension across the participating devices and to use All-to-All communication so that each device computes a different subset of the attention heads. This reduces activation memory and avoids OOM failures in long-sequence scenarios.
+
+Specifically, Ulysses partitions each sample along the sequence dimension across the participating devices. Before the attention computation, it performs an All-to-All on the partitioned queries (Q), keys (K), and values (V) so that every device receives the full sequence, but only for a non-overlapping subset of the attention heads, which lets the devices compute different attention heads in parallel. Finally, another All-to-All gathers the results across the attention heads while re-partitioning along the sequence dimension.
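+
+To make the head partitioning concrete, below is a minimal, single-process sketch that only simulates the Ulysses data movement with tensor slicing. It is not the DeepSpeed or MindSpeed RL implementation, and all names (`cp_size`, `q_per_rank`, ...) are illustrative. It assumes the sequence length and the number of heads are both divisible by the CP size.
+
+```
+import torch
+
+# Illustrative sizes; Ulysses requires seq_len % cp_size == 0 and num_heads % cp_size == 0.
+cp_size, batch, seq_len, num_heads, head_dim = 2, 1, 8, 4, 16
+
+# Each CP rank starts with a slice of the sequence: [batch, seq_len/cp, heads, head_dim].
+full_q = torch.randn(batch, seq_len, num_heads, head_dim)
+q_per_rank = list(full_q.chunk(cp_size, dim=1))
+
+# First All-to-All (simulated): every rank ends up with the FULL sequence
+# but only heads/cp attention heads: [batch, seq_len, heads/cp, head_dim].
+heads_per_rank = num_heads // cp_size
+q_after_a2a = []
+for rank in range(cp_size):
+    head_chunk = slice(rank * heads_per_rank, (rank + 1) * heads_per_rank)
+    q_after_a2a.append(torch.cat([q_per_rank[r][:, :, head_chunk, :] for r in range(cp_size)], dim=1))
+
+# ... each rank now runs attention over the full sequence for its head subset; the
+# attention output has the same layout, so q_after_a2a stands in for it below ...
+
+# Second All-to-All (simulated): gather the heads back while re-partitioning the
+# sequence, restoring the original [batch, seq_len/cp, heads, head_dim] layout.
+seq_per_rank = seq_len // cp_size
+q_restored = []
+for rank in range(cp_size):
+    seq_chunk = slice(rank * seq_per_rank, (rank + 1) * seq_per_rank)
+    q_restored.append(torch.cat([q_after_a2a[r][:, seq_chunk, :, :] for r in range(cp_size)], dim=2))
+
+assert all(torch.equal(a, b) for a, b in zip(q_restored, q_per_rank))
+```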
+
+
+## Usage
+
+Context Parallel in this repository currently supports Ulysses-style sequence partitioning, which can be enabled with the following configuration:
+```
+actor_config:
+ context_parallel_size: 2
+ context_parallel_algo: ulysses_cp_algo
+```
+
+Where:
+
+`context_parallel_size` is the context-parallel (CP) degree. When ulysses_cp_algo is selected, the model must satisfy **num_attention_heads % (CP * TP) == 0** (a minimal check is sketched at the end of this document).
+
+`context_parallel_algo` selects the context-parallel algorithm. Currently only **ulysses_cp_algo** is supported; if this parameter is not set, it defaults to **ulysses_cp_algo**.
+
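+Before launching a job, it can help to verify the head-divisibility constraint up front. The snippet below is only an illustrative, stand-alone check; the argument values are assumptions and it is not part of this repository's configuration validation:
+
+```
+def check_ulysses_constraint(num_attention_heads: int, tp_size: int, cp_size: int) -> None:
+    """Raise early if the attention heads cannot be evenly split across TP * CP ranks."""
+    if num_attention_heads % (tp_size * cp_size) != 0:
+        raise ValueError(
+            f"num_attention_heads ({num_attention_heads}) must be divisible by "
+            f"TP * CP ({tp_size} * {cp_size} = {tp_size * cp_size}) for ulysses_cp_algo."
+        )
+
+
+# Example: a model with 28 attention heads (e.g. Qwen2.5-7B) works with TP=2, CP=2, since 28 % 4 == 0.
+check_ulysses_constraint(num_attention_heads=28, tp_size=2, cp_size=2)
+```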
diff --git a/mindspeed_rl/config_cls/megatron_config.py b/mindspeed_rl/config_cls/megatron_config.py
index f27e7c3bad563ba4a32e769693edff499408a750..38a47134d87cedc531ed84c01fc7a89842e15f71 100644
--- a/mindspeed_rl/config_cls/megatron_config.py
+++ b/mindspeed_rl/config_cls/megatron_config.py
@@ -138,6 +138,7 @@ class MegatronConfig(BaseConfig):
dataloader_type: Single pass vs multiple pass data loader (default: None)
enable_high_availability: Switch of the high availability feature (default: False)
context_parallel_size: Degree of context parallelism (default: 1)
+ context_parallel_algo: Algorithm of context parallelism (default: ulysses_cp_algo)
reset_position_ids: Reset posistion ids after end-of-document token (default: False)
optimizer: Optimizer function (default: 'adam')
do_sample: Enable doing sample in actor generations (default: False)
@@ -326,6 +327,7 @@ class MegatronConfig(BaseConfig):
self.dataloader_type = None
self.enable_high_availability = False
self.context_parallel_size = 1
+ self.context_parallel_algo = "ulysses_cp_algo"
self.reset_position_ids = False
self.optimizer = 'adam'
self.do_sample = False
diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py
index 70a04f35d34801a2127fdbd161661f730f59dc19..56c3f1c5d349a63bcda72be03515114fdeee0b12 100644
--- a/mindspeed_rl/config_cls/validate_config.py
+++ b/mindspeed_rl/config_cls/validate_config.py
@@ -49,6 +49,10 @@ def validate_rl_args(
f"Actor.seq_length={actor_config.seq_length} vs "
f"GenerateConfig.max_model_len={generate_config.max_model_len}")
+ if actor_config.context_parallel_size > 1 and actor_config.context_parallel_algo is not None:
+ if actor_config.context_parallel_algo not in ["ulysses_cp_algo"]:
+            raise ValueError("Currently only ulysses_cp_algo is supported for context parallelism.")
+
# 校验移除填充特性相关配置
if rl_config.use_remove_padding:
if actor_config.pipeline_model_parallel_size > 1 and not actor_config.variable_seq_lengths:
diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py
index 00c58448c366dd348e923d6aa2551cadfb3cede5..27308d7497ae6a2f6353ae0c44e7e75e5802fccc 100644
--- a/mindspeed_rl/models/base/base_training_engine.py
+++ b/mindspeed_rl/models/base/base_training_engine.py
@@ -14,6 +14,7 @@ from mindspeed_rl.utils.utils import (
)
from mindspeed_rl.utils.remove_padding import preprocess_packed_seqs, postprocess_packed_seqs
from mindspeed_rl.utils.compute import get_parallel_state
+from mindspeed_rl.utils.utils import get_batch_on_this_cp_rank
class BaseTrainingEngine(ABC):
@@ -56,6 +57,8 @@ class BaseTrainingEngine(ABC):
set_actual_seq_len: Callable = None,
forward_backward_func: Callable = None,
entropy_coeff: float = 0.0,
+ context_parallel_algo: str = "ulysses_cp_algo",
+ context_parallel_size: int = 1,
kl_penalty: str = "low_var_kl",
**kwargs):
self.forward_backward_func = forward_backward_func
@@ -75,6 +78,8 @@ class BaseTrainingEngine(ABC):
self.kl_penalty = kl_penalty
self.clip_ratio = clip_ratio
self.entropy_coeff = entropy_coeff
+ self.context_parallel_algo = context_parallel_algo
+ self.context_parallel_size = context_parallel_size
self.temperature = temperature
self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role)
self.kwargs = kwargs
@@ -102,8 +107,9 @@ class BaseTrainingEngine(ABC):
data_iter = [iter(batches) for _ in self.model]
self.loss_func.add_loss_meta_info(self.get_loss_meta_func())
post_process = get_parallel_state().get_pipeline_model_parallel_world_size() == 1 or get_parallel_state().is_pipeline_last_stage()
-
+
def forward_step(batch_iter, model):
+ cp_size = get_parallel_state().get_context_parallel_world_size()
if self.use_remove_padding:
input_ids, position_ids, process_batch, seqlens_in_batch, cu_seqlens_padded = self._get_forward_batch_info(batch_iter)
self.set_actual_seq_len(cu_seqlens_padded.tolist())
@@ -118,6 +124,12 @@ class BaseTrainingEngine(ABC):
else:
input_ids, attention_mask, position_ids, process_batch = self._get_forward_batch_info(batch_iter)
output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
+ if post_process:
+ if cp_size > 1:
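+                        # each CP rank's output covers only its slice of the sequence; gather the
+                        # slices over the CP group (keeping this rank's tensor for autograd) and
+                        # concatenate along the sequence dimension to recover full-sequence logits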
+ output_list = [torch.empty_like(output) for _ in range(cp_size)]
+ torch.distributed.all_gather(output_list, output.detach(), group=get_parallel_state().get_context_parallel_group())
+ output_list[get_parallel_state().get_context_parallel_rank()] = output
+ output = torch.cat(output_list, dim=1)
output.div_(self.temperature)
return output, partial(self.loss_func.compute_loss, batch=process_batch, forward_only=forward_only)
@@ -146,6 +158,7 @@ class BaseTrainingEngine(ABC):
input_ids = batch['input_ids']
attention_mask_1d = generate_mask(input_ids, batch['prompt_length'] + batch['response_length']).to(
input_ids.device)
+ cp_size = get_parallel_state().get_context_parallel_world_size()
if self.use_remove_padding:
tp_size = get_parallel_state().get_tensor_model_parallel_world_size()
input_ids, position_ids, seqlens_in_batch, cu_seqlens_padded = preprocess_packed_seqs(
@@ -154,7 +167,18 @@ class BaseTrainingEngine(ABC):
else:
position_ids = torch.tensor(generate_position_ids(input_ids)).to(input_ids.device)
attention_mask = get_tune_attention_mask(attention_mask_1d)
- return input_ids, attention_mask, position_ids, batch
+ if cp_size > 1:
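+                # with CP enabled, keep only this rank's chunk of input_ids/position_ids along
+                # the sequence dimension; the full attention mask is dropped because it no
+                # longer matches the sliced inputs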
+ batch_for_cp = {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ 'position_ids': position_ids
+ }
+ batch_cp = get_batch_on_this_cp_rank(self.context_parallel_algo, self.context_parallel_size, batch_for_cp)
+ input_ids = batch_cp['input_ids']
+ attention_mask = None
+ position_ids = batch_cp['position_ids']
+
+ return input_ids, attention_mask, position_ids, batch
def post_process_forward_backward_output(self, output: [torch.Tensor],
batch: Dict[str, torch.Tensor]) -> torch.Tensor:
diff --git a/mindspeed_rl/trainer/utils/parallel_state.py b/mindspeed_rl/trainer/utils/parallel_state.py
index 788e287f4dee50ed012576ccd648a4587703dde5..5e47fcfb9614df678257d0538a186d7dff67903d 100644
--- a/mindspeed_rl/trainer/utils/parallel_state.py
+++ b/mindspeed_rl/trainer/utils/parallel_state.py
@@ -41,6 +41,17 @@ def get_tensor_model_parallel_rank(mpu, use_vllm=False):
return mpu.get_tensor_model_parallel_rank()
+def get_context_parallel_rank(mpu, use_vllm=False):
+ if use_vllm:
+ from vllm.distributed import parallel_state as vpu
+ if not hasattr(vpu, "get_context_parallel_rank"):
+ vpu = mpu
+ return vpu.get_context_parallel_rank()
+ else:
+ return mpu.get_context_parallel_rank()
+
+
def get_tensor_model_parallel_src_rank(mpu, use_vllm=False):
if use_vllm:
from vllm.distributed import parallel_state as vpu
@@ -49,6 +60,13 @@ def get_tensor_model_parallel_src_rank(mpu, use_vllm=False):
return mpu.get_tensor_model_parallel_src_rank()
+def get_context_parallel_src_rank(mpu, use_vllm=False):
+ if use_vllm:
+        raise NotImplementedError("not implemented for vLLM yet.")
+ else:
+ return mpu.get_context_parallel_global_ranks()[0]
+
+
def get_tensor_model_parallel_group(mpu, use_vllm=False):
if use_vllm:
from vllm.distributed import parallel_state as vpu
@@ -57,10 +75,35 @@ def get_tensor_model_parallel_group(mpu, use_vllm=False):
return mpu.get_tensor_model_parallel_group()
+def get_context_parallel_group(mpu, use_vllm=False):
+ if use_vllm:
+        raise NotImplementedError("not implemented for vLLM yet.")
+ else:
+ return mpu.get_context_parallel_group()
+
+
def get_model_parallel_group(mpu, use_vllm=False):
if use_vllm:
import vllm
from vllm.distributed import parallel_state as vpu
return vpu.get_tensor_model_parallel_group().device_group
else:
- return mpu.get_model_parallel_group()
\ No newline at end of file
+ return mpu.get_model_parallel_group()
+
+
+def get_tensor_and_context_parallel_rank(mpu, use_vllm=False):
+ """Return caller's rank in the joint tensor-model-parallel and context-parallel group."""
+ if use_vllm:
+ from vllm.distributed import parallel_state as vpu
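+        # vLLM defines no context-parallel groups, so the joint TP-CP rank reduces to the plain TP rank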
+ return vpu.get_tensor_model_parallel_rank()
+ else:
+ return mpu.get_tensor_and_context_parallel_rank()
+
+
+def get_tensor_and_context_parallel_group(mpu, use_vllm=False):
+ """Get the tensor- and context-parallel group the caller rank belongs to."""
+ if use_vllm:
+ from vllm.distributed import parallel_state as vpu
+ return vpu.get_tensor_model_parallel_group().device_group
+ else:
+ return mpu.get_tensor_and_context_parallel_group()
\ No newline at end of file
diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py
index 44d5a631aab266aee2432555c2dab056ff456074..2358dd0b6e312cd63871b694290fb20dd7f7055d 100644
--- a/mindspeed_rl/utils/utils.py
+++ b/mindspeed_rl/utils/utils.py
@@ -434,6 +434,23 @@ def get_attr_wrapped_model(model, attr, allow_none=True, return_model_obj=False)
return getattr(model, attr)
+def get_batch_on_this_cp_rank(context_parallel_algo, context_parallel_size, batch):
+ """ Slice batch input along sequence dimension into multiple chunks,
+ which are parallelized across GPUs in a context parallel group.
+ """
+ from mindspeed.utils import (set_actual_seq_len, set_position_ids,
+ _get_batch_on_this_cp_rank_in_ulysses_cp)
+
+ if context_parallel_size <= 1:
+ return batch
+
+ if context_parallel_algo == 'ulysses_cp_algo':
+ batch = _get_batch_on_this_cp_rank_in_ulysses_cp(batch)
+ else:
+        raise NotImplementedError("Only ulysses_cp_algo is supported.")
+ return batch
+
+
def get_grpo_profiler(profiler_config, role: str = None):
args = profiler_config
if not args or not args.profile:
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 69e5cd17f97cbe71d628a1d5b4403b5022bf75ab..1acdd54a6388bb658e3d21bac1ab6663c83979d2 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -111,6 +111,8 @@ class ActorHybridWorkerBase(BaseWorker):
micro_batch_size=self.megatron_config.micro_batch_size,
use_remove_padding=self.rl_config.use_remove_padding,
set_actual_seq_len=megatron_module['set_actual_seq_len'],
+ context_parallel_algo=self.megatron_config.context_parallel_algo,
+ context_parallel_size=self.megatron_config.context_parallel_size,
entropy_coeff=self.rl_config.entropy_coeff,
kl_penalty=self.rl_config.kl_penalty,
temperature=self.generate_config.sampling_config["temperature"]
@@ -172,6 +174,8 @@ class ActorHybridWorkerBase(BaseWorker):
experience_columns,
experience_count,
self.megatron_config.tensor_model_parallel_size,
+ self.megatron_config.context_parallel_size,
+ self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=False)
@@ -184,7 +188,7 @@ class ActorHybridWorkerBase(BaseWorker):
self.args.consumed_train_samples += self.megatron_config.global_batch_size // self.rl_config.n_samples_per_prompt
self.num_floating_point_operations_so_far += num_floating_point_operations(self.args,
self.megatron_config.global_batch_size)
- if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.parallel_state.get_tensor_model_parallel_rank() == 0:
+ if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.parallel_state.get_tensor_model_parallel_rank() == 0 and self.parallel_state.get_context_parallel_rank() == 0:
ray.get(self.td.update_metrics.remote(value=metrics, cumulate=True))
ray.get(
self.td.update_metrics.remote(
@@ -243,6 +247,8 @@ class ActorHybridWorkerBase(BaseWorker):
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
+ cp_size=self.megatron_config.context_parallel_size,
+ cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
use_vllm=True
)
@@ -322,6 +328,8 @@ class ActorHybridWorkerBase(BaseWorker):
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
+ cp_size=self.megatron_config.context_parallel_size,
+ cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=False)
diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py
index 51d91a3624a4c8da89cb922973f033d48c57a411..d5aa486f62fc90d992a1c18986b6b407028e0a28 100644
--- a/mindspeed_rl/workers/base_worker.py
+++ b/mindspeed_rl/workers/base_worker.py
@@ -10,6 +10,7 @@ import socket
import torch
import torch_npu
import ray
+import torch.distributed as dist
from mindspeed_rl.models.rollout.vllm_adapter.vllm_parallel_state import get_vllm_tp_group_ranks
from mindspeed_rl.utils.loggers import Loggers
@@ -26,8 +27,11 @@ from mindspeed_rl.trainer.utils.parallel_state import (
is_pipeline_last_stage,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
+ get_context_parallel_rank,
get_tensor_model_parallel_src_rank,
- get_model_parallel_group
+ get_model_parallel_group,
+ get_context_parallel_src_rank,
+ get_context_parallel_group
)
from mindspeed_rl.utils.compute import set_parallel_state, set_vocab_parallel
from mindspeed_rl.utils.utils import get_current_dp_range_indexes
@@ -168,12 +172,24 @@ class BaseWorker(BaseRayWorker, ABC):
else:
current_device = next(self.model[0].parameters()).device
status = torch.tensor(0, device=current_device)
- if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
- get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+
+ rank_flg = False
+ if not use_vllm:
+ rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+ get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and
+ get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+ else:
+ rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+ get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+ if rank_flg:
status = torch.tensor(int(not ray.get(self.td.all_consumed.remote(experience_consumer_stage))),
device=current_device)
torch.distributed.all_reduce(status, group=get_model_parallel_group(self.parallel_state, use_vllm),
op=torch.distributed.ReduceOp.MAX)
+ if not use_vllm:
+ torch.distributed.all_reduce(status, group=get_context_parallel_group(self.parallel_state, use_vllm),
+ op=torch.distributed.ReduceOp.MAX)
+
return status
def setup_distributed_rank(self):
@@ -223,17 +239,26 @@ class BaseWorker(BaseRayWorker, ABC):
@mstx_timer_decorator
def dispatch_transfer_dock_data(self, experience_consumer_stage,
- experience_columns, experience_count, tp_size=1,
+ experience_columns, experience_count, tp_size=1, cp_size=1, cp_algo=None,
use_vllm=False, indexes=None,
get_n_samples=True):
pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
batch_data = {}
batch_data_length = {}
- # make sure that all ranks in tp/pp group enter dispatch_transfer_dock_data,
+ # make sure that all ranks in cp/tp/pp group enter dispatch_transfer_dock_data,
# in case of rank0 get_experience before other ranks judge td.all_consumed
- if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
- get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+
+ rank_flg = False
+ if not use_vllm:
+ rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+ get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and
+ get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+ else:
+ rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and
+ get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0)
+
+ if rank_flg:
batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns,
experience_count, indexes=indexes,
get_n_samples=get_n_samples)) # cpu数据
@@ -244,11 +269,17 @@ class BaseWorker(BaseRayWorker, ABC):
else:
index = torch.empty(experience_count, device=torch.cuda.current_device(), dtype=torch.int64)
- # # 传输index, 并判断是否取出了数据
torch.distributed.broadcast(
index, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
)
+
+ if not use_vllm:
+ torch.distributed.broadcast(
+ index, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_context_parallel_group(self.parallel_state, use_vllm)
+ )
+
torch.distributed.broadcast(
index, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
@@ -257,13 +288,11 @@ class BaseWorker(BaseRayWorker, ABC):
if index[0].item() == -1:
return None, None
- if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
- get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+ if rank_flg:
batch_data, batch_data_length = pack_experience_columns(batch_data, experience_count)
for key in experience_columns:
- if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and \
- get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0:
+ if rank_flg:
batch_data_shape = torch.tensor(batch_data[key].shape,
dtype=torch.int64, device=torch.cuda.current_device())
@@ -280,7 +309,7 @@ class BaseWorker(BaseRayWorker, ABC):
batch_data_dtype = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64)
batch_data_length_shape = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64)
- # 传输tensor数据形状和类型
+ # TP domain sync
torch.distributed.broadcast(
batch_data_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
@@ -289,6 +318,23 @@ class BaseWorker(BaseRayWorker, ABC):
batch_data_dtype, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
)
+ torch.distributed.broadcast(batch_data_length_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_tensor_model_parallel_group(self.parallel_state, use_vllm))
+
+ # CP domain sync
+ if not use_vllm:
+ torch.distributed.broadcast(
+ batch_data_shape, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_context_parallel_group(self.parallel_state, use_vllm)
+ )
+ torch.distributed.broadcast(
+ batch_data_dtype, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_context_parallel_group(self.parallel_state, use_vllm)
+ )
+ torch.distributed.broadcast(batch_data_length_shape, get_context_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_context_parallel_group(self.parallel_state, use_vllm))
+
+ # PP domain sync
torch.distributed.broadcast(
batch_data_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
@@ -297,14 +343,11 @@ class BaseWorker(BaseRayWorker, ABC):
batch_data_dtype, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
)
-
- torch.distributed.broadcast(batch_data_length_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
- group=get_tensor_model_parallel_group(self.parallel_state, use_vllm))
torch.distributed.broadcast(batch_data_length_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm))
- if get_tensor_model_parallel_rank(self.parallel_state, use_vllm) != 0 or \
- get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) != 0:
+
+ if not rank_flg:
if batch_data_dtype == 1:
batch_data[key] = torch.empty(batch_data_shape[0], # batch_data_shape[1],
device=torch.cuda.current_device(),
@@ -320,6 +363,13 @@ class BaseWorker(BaseRayWorker, ABC):
batch_data[key].cuda(), get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)
)
+
+ if not use_vllm:
+ torch.distributed.broadcast(
+ batch_data[key].cuda(), get_context_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_context_parallel_group(self.parallel_state, use_vllm)
+ )
+
torch.distributed.broadcast(
batch_data[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)
@@ -328,13 +378,19 @@ class BaseWorker(BaseRayWorker, ABC):
torch.distributed.broadcast(batch_data_length[key].cuda(), get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_tensor_model_parallel_group(self.parallel_state, use_vllm))
+ if not use_vllm:
+ torch.distributed.broadcast(batch_data_length[key].cuda(), get_context_parallel_src_rank(self.parallel_state, use_vllm),
+ group=get_context_parallel_group(self.parallel_state, use_vllm))
+
torch.distributed.broadcast(batch_data_length[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm),
group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm))
index_without_pad = index.cpu().numpy().tolist()[:batch_data_shape[0]]
if batch_data:
- padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, tp_size)
+            if cp_algo in ['ulysses_cp_algo']:
+                padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, tp_size * cp_size)
+            else:
+                padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, tp_size)
+
return padded_batch_data, index_without_pad
else:
return {}, []
diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py
index a7035ffd12c8c79be3e20d23986e2e658dabb0a6..a0ae627bc7d24778b3312b8f905a54452988619b 100644
--- a/mindspeed_rl/workers/integrated_worker.py
+++ b/mindspeed_rl/workers/integrated_worker.py
@@ -103,6 +103,8 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
+ context_parallel_algo=self.megatron_config.context_parallel_algo,
+ context_parallel_size=self.megatron_config.context_parallel_size,
use_remove_padding=self.rl_config.use_remove_padding,
set_actual_seq_len=megatron_module['set_actual_seq_len'],
temperature=self.generate_config.sampling_config["temperature"]
diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py
index 04b87738ddc55cc34bac8b3bb9e826d6f00a81b3..fc5423777854caed3174a14870ad398178c8d6ff 100644
--- a/mindspeed_rl/workers/reference_woker.py
+++ b/mindspeed_rl/workers/reference_woker.py
@@ -15,9 +15,12 @@ from mindspeed_rl.utils.pad_process import truncate_rows
from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.workers.base_worker import BaseWorker
from mindspeed_rl.utils.compute import get_parallel_state
-from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank
+from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank, get_context_parallel_rank
+from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.utils.utils import mstx_timer_decorator
+logger = Loggers(__name__)
+
class ReferenceWorkerBase(BaseWorker):
"""
@@ -80,6 +83,8 @@ class ReferenceWorkerBase(BaseWorker):
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
+ context_parallel_algo=self.megatron_config.context_parallel_algo,
+ context_parallel_size=self.megatron_config.context_parallel_size,
use_remove_padding=self.rl_config.use_remove_padding,
set_actual_seq_len=megatron_module['set_actual_seq_len'],
temperature=self.generate_config.sampling_config["temperature"]
@@ -102,10 +107,12 @@ class ReferenceWorkerBase(BaseWorker):
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
+ cp_size=self.megatron_config.context_parallel_size,
+ cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=False)
-
+
if not start_time_defined:
start_time = time.time()
start_time_defined = True
@@ -137,7 +144,7 @@ class ReferenceWorkerBase(BaseWorker):
parallel_state = get_parallel_state()
use_vllm = False
- if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0:
+        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0 and get_context_parallel_rank(parallel_state, use_vllm) == 0:
ref_end_time = time.time()
ray.get(
self.td.update_metrics.remote(
diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py
index 182b695f118e1345ad77adc9345017b46abd340e..eebfb95df5eb5aa66f1389220f06ffce115be58e 100644
--- a/mindspeed_rl/workers/reward_woker.py
+++ b/mindspeed_rl/workers/reward_woker.py
@@ -14,7 +14,7 @@ from mindspeed_rl.trainer.utils.compute_utils import get_last_reward
from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.workers.base_worker import BaseWorker
from mindspeed_rl.utils.compute import get_parallel_state
-from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank
+from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank, get_context_parallel_rank
class RewardWorkerBase(BaseWorker):
@@ -74,6 +74,8 @@ class RewardWorkerBase(BaseWorker):
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
+ context_parallel_algo=self.megatron_config.context_parallel_algo,
+ context_parallel_size=self.megatron_config.context_parallel_size,
use_remove_padding=self.rl_config.use_remove_padding,
set_actual_seq_len=megatron_module['set_actual_seq_len'],
temperature=self.generate_config.sampling_config["temperature"]
@@ -104,6 +106,8 @@ class RewardWorkerBase(BaseWorker):
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
+ cp_size=self.megatron_config.context_parallel_size,
+ cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
)
@@ -139,7 +143,7 @@ class RewardWorkerBase(BaseWorker):
)
parallel_state = get_parallel_state()
use_vllm = False
- if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0:
+        if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0 and get_context_parallel_rank(parallel_state, use_vllm) == 0:
rwd_end_time = time.time()
ray.get(
self.td.update_metrics.remote(
diff --git a/tests/ut/workers/test_base_worker.py b/tests/ut/workers/test_base_worker.py
index f6bc36e976664f9d9beae3cfdaae9e5fa28f01dd..28aca380450f06b253b7863471a44441fe7a1922 100644
--- a/tests/ut/workers/test_base_worker.py
+++ b/tests/ut/workers/test_base_worker.py
@@ -148,4 +148,4 @@ class TestBaseWorker(DistributedTest):
_, _ = worker.dispatch_transfer_dock_data(experience_consumer_stage,
experience_columns, experience_count)
- assert mock_broadcast.call_count == 2
+ assert mock_broadcast.call_count == 3