From c220ed643a7c60caca467efb1c106408ebb51b42 Mon Sep 17 00:00:00 2001 From: nurxat Date: Thu, 3 Jul 2025 20:26:36 +0800 Subject: [PATCH 1/2] dynamic --- docs/features/remove_padding.md | 14 ++++-- mindspeed_rl/config_cls/rl_config.py | 8 +++- .../models/base/base_training_engine.py | 40 +++++++++------- mindspeed_rl/models/loss/base_loss_func.py | 48 +++++++++---------- mindspeed_rl/workers/actor_hybrid_worker.py | 21 ++++---- mindspeed_rl/workers/critic_worker.py | 7 +-- mindspeed_rl/workers/integrated_worker.py | 11 +++-- mindspeed_rl/workers/reference_woker.py | 11 +++-- mindspeed_rl/workers/reward_woker.py | 8 ++-- ...est_grpo_trainer_qwen25_7b_integrated.yaml | 4 +- 10 files changed, 91 insertions(+), 81 deletions(-) diff --git a/docs/features/remove_padding.md b/docs/features/remove_padding.md index 089a8e08..514a4346 100644 --- a/docs/features/remove_padding.md +++ b/docs/features/remove_padding.md @@ -68,8 +68,8 @@ rl_config: ## 参数说明: -`max_packing_token_size` 是动态批大小(Dynamic Batch Size)机制中的核心参数,用于限制每个拼接后的 micro batch 中 token 的总数,防止因拼接过多序列而导致显存溢出(OOM)。 -`dynamic_max_batch_size` 用于限制最大的 micro batch,防止在长序列训练场景下,有多个短序列放入同一批次导致 micro batch size 过大,进而导致 OOM。 +`ref_max_packing_token_size`, `actor_max_packing_token_size`, `update_max_packing_token_size` 是动态批大小(Dynamic Batch Size)机制中的核心参数,用于限制每个拼接后的 micro batch 中 token 的总数,防止因拼接过多序列而导致显存溢出(OOM)。 +`ref_dynamic_max_batch_size`, `actor_dynamic_max_batch_size`, `update_dynamic_max_batch_size` 是控制 Dynamic batch size 分箱之后每个批次中最大的序列条数 micro batch size,防止在长序列训练场景下,有多个短序列放入同一批次导致 micro batch size 过大,进而导致 OOM。 **使用限制**:每条样本的 token 长度必须满足: ```text @@ -80,7 +80,7 @@ prompt_length[i] + response_length[i] <= max_packing_token_size ```text max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling_config.max_tokens) * 2 ``` -`dynamic_max_batch_size` 是可选参数。如果长序列训练过程发生 OOM,且发生在计算得出 logits 之后,可以通过设置或减小该值减少显存占用,建议最小设置为2,若设置为1,则 Dynamic Batch Size 无意义。 +`*_dynamic_max_batch_size` 是可选参数。如果长序列训练过程发生 OOM,且发生在计算得出 logits 之后,可以通过设置减小该值减少显存占用,建议最小设置为2,若设置为1,则 Dynamic Batch Size 无意义。 二者可以根据实际需求调整。 @@ -93,8 +93,12 @@ max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling ```yaml rl_config: use_dynamic_bsz: true - max_packing_token_size: 8192 - dynamic_max_batch_size: 8 # 可选参数 + ref_max_packing_token_size: 8192 + ref_dynamic_max_batch_size: 8 # 可选参数 + actor_max_packing_token_size: 8192 + actor_dynamic_max_batch_size: 8 # 可选参数 + update_max_packing_token_size: 8192 + update_dynamic_max_batch_size: 8 # 可选参数 ``` # 📦 数据并行负载均衡(DP Batch Balance)特性 diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py index 9cf87bb2..3a7a9c5f 100644 --- a/mindspeed_rl/config_cls/rl_config.py +++ b/mindspeed_rl/config_cls/rl_config.py @@ -121,8 +121,12 @@ class RLConfig(BaseConfig): self.mini_batch_size = 1 self.use_dynamic_bsz = False - self.max_packing_token_size = 4096 - self.dynamic_max_batch_size = None + self.ref_max_packing_token_size = 4096 + self.actor_max_packing_token_size = 4096 + self.update_max_packing_token_size = 4096 + self.ref_dynamic_max_batch_size = None + self.actor_dynamic_max_batch_size = None + self.update_dynamic_max_batch_size = None # token level loss self.token_level_loss = True diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py index 044fd36c..1330daee 100644 --- a/mindspeed_rl/models/base/base_training_engine.py +++ b/mindspeed_rl/models/base/base_training_engine.py @@ -62,16 +62,8 @@ class BaseTrainingEngine(ABC): temperature: float = 1.0, role: str = None, micro_batch_size: int = 1, - use_dynamic_bsz: bool = False, - max_packing_token_size: bool = 4096, - dynamic_max_batch_size: int = None, - use_remove_padding: bool = False, - set_actual_seq_len: Callable = None, - get_actual_seq_len: Callable = None, - set_position_ids: Callable = None, forward_backward_func: Callable = None, entropy_coeff: float = 0.0, - context_parallel_size: int = 1, kl_penalty: str = "low_var_kl", token_level_loss: bool = False, clip_higher_enable: bool = False, @@ -81,13 +73,6 @@ class BaseTrainingEngine(ABC): **kwargs): self.forward_backward_func = forward_backward_func self.micro_batch_size = micro_batch_size - self.use_dynamic_bsz = use_dynamic_bsz - self.max_packing_token_size = max_packing_token_size - self.dynamic_max_batch_size = dynamic_max_batch_size - self.use_remove_padding = use_remove_padding - self.set_actual_seq_len = set_actual_seq_len - self.get_actual_seq_len = get_actual_seq_len - self.set_position_ids = set_position_ids self.model = model self.megatron_config = megatron_config self.optimizer = optimizer @@ -102,7 +87,6 @@ class BaseTrainingEngine(ABC): self.kl_penalty = kl_penalty self.clip_ratio = clip_ratio self.entropy_coeff = entropy_coeff - self.context_parallel_size = context_parallel_size self.temperature = temperature self.token_level_loss = token_level_loss self.clip_higher_enable = clip_higher_enable @@ -111,7 +95,21 @@ class BaseTrainingEngine(ABC): self.cliprange_value = cliprange_value self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role) self.kwargs = kwargs - + + self.use_remove_padding = kwargs.get('use_remove_padding', False) + self.use_dynamic_bsz = kwargs.get('use_dynamic_bsz', False) + self.max_packing_token_size = kwargs.get('ref_max_packing_token_size', None) + self.dynamic_max_batch_size = kwargs.get('ref_dynamic_max_batch_size', None) + if self.max_packing_token_size is None: + self.max_packing_token_size = {'actor': kwargs.get('actor_max_packing_token_size', None), + 'update': kwargs.get('update_max_packing_token_size', None)} + self.dynamic_max_batch_size = {'actor': kwargs.get('actor_dynamic_max_batch_size', None), + 'update': kwargs.get('update_dynamic_max_batch_size', None)} + self.context_parallel_size = kwargs.get('context_parallel_size', 1) + self.set_actual_seq_len = kwargs.get('set_actual_seq_len', None) + self.get_actual_seq_len = kwargs.get('get_actual_seq_len', None) + self.set_position_ids = kwargs.get('set_position_ids', None) + @staticmethod def _split_batches(batch: Dict, batch_size: int, shuffle_mini_batch: bool, dim: int = 0, keep_list: bool = False) -> List[Dict]: batches = [] @@ -149,7 +147,13 @@ class BaseTrainingEngine(ABC): def _forward_backward_batch(self, batch: Dict[str, torch.Tensor], forward_only: bool = False): if self.use_dynamic_bsz: - batches, indices = self._split_batches_with_dynamic_bsz(batch, self.max_packing_token_size, self.dynamic_max_batch_size) + if isinstance(self.max_packing_token_size, dict): + max_packing_token_size = self.max_packing_token_size['actor'] if forward_only else self.max_packing_token_size['update'] # actor forward or update + dynamic_max_batch_size = self.dynamic_max_batch_size['actor'] if forward_only else self.dynamic_max_batch_size['update'] + else: + max_packing_token_size = self.max_packing_token_size # reference forward + dynamic_max_batch_size = self.dynamic_max_batch_size + batches, indices = self._split_batches_with_dynamic_bsz(batch, max_packing_token_size, dynamic_max_batch_size) else: batches = self._split_batches(batch, batch_size=self.micro_batch_size, shuffle_mini_batch=self.shuffle_mini_batch) diff --git a/mindspeed_rl/models/loss/base_loss_func.py b/mindspeed_rl/models/loss/base_loss_func.py index 77df2821..0b92f2be 100644 --- a/mindspeed_rl/models/loss/base_loss_func.py +++ b/mindspeed_rl/models/loss/base_loss_func.py @@ -45,35 +45,31 @@ class BaseLossFunc(ABC): return responses, logits def compute_log_probs(self, output, batch: Dict[str, torch.Tensor], update=False, **kwargs): + # 当 use_remove_padding 为 True 时,输入的 output 是一个 pack 后的 tensor,形状为 [1, seq_len, vocab_size]。 + # 计算 log_probs 时需要根据 seqlens_in_batch 和 cu_seqlens_padded 对 output 进行切片。 use_remove_padding = kwargs.get('use_remove_padding', False) - if use_remove_padding: - seqlens_in_batch = kwargs.get('seqlens_in_batch', None) - cu_seqlens_padded = kwargs.get('cu_seqlens_padded', None) - batch_size = seqlens_in_batch.shape[0] - log_probs_list = [] - entropy_list = [] - for i in range(batch_size): + seqlens_in_batch = kwargs.get('seqlens_in_batch', None) + cu_seqlens_padded = kwargs.get('cu_seqlens_padded', None) + batch_size = seqlens_in_batch.shape[0] if use_remove_padding else output.shape[0] + log_probs_list = [] + entropy_list = [] + for i in range(batch_size): + if use_remove_padding: start = cu_seqlens_padded[i].item() length = seqlens_in_batch[i].item() single_output = output[0, start:start + length].unsqueeze(0) # [1, length, vocab_size] - single_batch = {key: value[i].unsqueeze(0) for key, value in batch.items()} - response, logits = self._get_compute_log_probs_input(single_output, single_batch) - single_log_probs = compute_log_probs(logits, response) - log_probs_list.append(single_log_probs) - if update: - single_entropy = vocab_parallel_entropy(logits) - entropy_list.append(single_entropy) - log_probs = torch.cat(log_probs_list, dim=0) - if update: - entropy = torch.cat(entropy_list, dim=0) - return log_probs, entropy else: - return log_probs - else: - responses, logits = self._get_compute_log_probs_input(output, batch) - log_probs = compute_log_probs(logits, responses) + single_output = output[i].unsqueeze(0) + single_batch = {key: value[i].unsqueeze(0) for key, value in batch.items()} + response, logits = self._get_compute_log_probs_input(single_output, single_batch) + single_log_probs = compute_log_probs(logits, response) + log_probs_list.append(single_log_probs) if update: - entropy = vocab_parallel_entropy(logits) - return log_probs, entropy - else: - return log_probs \ No newline at end of file + single_entropy = vocab_parallel_entropy(logits) + entropy_list.append(single_entropy) + log_probs = torch.cat(log_probs_list, dim=0) + if update: + entropy = torch.cat(entropy_list, dim=0) + return log_probs, entropy + else: + return log_probs \ No newline at end of file diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index a7d433d8..92b9e5c2 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -119,21 +119,24 @@ class ActorHybridWorkerBase(BaseWorker): forward_backward_func=self.forward_backward_func, clip_ratio=self.rl_config.clip_ratio, micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, - use_remove_padding=self.rl_config.use_remove_padding, - set_actual_seq_len=self.set_actual_seq_len, - get_actual_seq_len=self.get_actual_seq_len, - set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, entropy_coeff=self.rl_config.entropy_coeff, kl_penalty=self.rl_config.kl_penalty, temperature=self.generate_config.sampling_config["temperature"], token_level_loss=self.rl_config.token_level_loss, clip_higher_enable=self.rl_config.clip_higher_enable, clip_ratio_low=self.rl_config.clip_ratio_low, - clip_ratio_high=self.rl_config.clip_ratio_high + clip_ratio_high=self.rl_config.clip_ratio_high, + + use_remove_padding=self.rl_config.use_remove_padding, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, + actor_max_packing_token_size=self.rl_config.actor_max_packing_token_size, + update_max_packing_token_size=self.rl_config.update_max_packing_token_size, + actor_dynamic_max_batch_size=self.rl_config.actor_dynamic_max_batch_size, + update_dynamic_max_batch_size=self.rl_config.update_dynamic_max_batch_size, + set_actual_seq_len=self.set_actual_seq_len, + get_actual_seq_len=self.get_actual_seq_len, + set_position_ids=self.set_position_ids, + context_parallel_size=self.megatron_config.context_parallel_size ) self.empty_cache() self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role) diff --git a/mindspeed_rl/workers/critic_worker.py b/mindspeed_rl/workers/critic_worker.py index 5bd6b80c..f564e7fb 100644 --- a/mindspeed_rl/workers/critic_worker.py +++ b/mindspeed_rl/workers/critic_worker.py @@ -69,7 +69,6 @@ class CriticWorkerBase(BaseWorker): self.critic_offloader.offload_grad() self.critic_offloader.offload_param() - megatron_module = self.get_megatron_module() self.critic = Critic( self.model, optimizer=self.optimizer, @@ -81,11 +80,7 @@ class CriticWorkerBase(BaseWorker): epochs=self.rl_config.epochs, shuffle_mini_batch=self.rl_config.shuffle_mini_batch, forward_backward_func=self.forward_backward_func, - micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - use_remove_padding=self.rl_config.use_remove_padding, - set_actual_seq_len=megatron_module['set_actual_seq_len'] + micro_batch_size=self.megatron_config.micro_batch_size ) self.empty_cache() diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py index b78c0a7c..6347663d 100644 --- a/mindspeed_rl/workers/integrated_worker.py +++ b/mindspeed_rl/workers/integrated_worker.py @@ -103,15 +103,16 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, + temperature=self.generate_config.sampling_config["temperature"], + use_remove_padding=self.rl_config.use_remove_padding, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, + ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size, + ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size, set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - temperature=self.generate_config.sampling_config["temperature"] + context_parallel_size=self.megatron_config.context_parallel_size ) MsProbe.config_init(self.msprobe_config) diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index aa75cdf0..5d36a806 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -82,15 +82,16 @@ class ReferenceWorkerBase(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, + temperature=self.generate_config.sampling_config["temperature"], + use_remove_padding=self.rl_config.use_remove_padding, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, + ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size, + ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size, set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - temperature=self.generate_config.sampling_config["temperature"] + context_parallel_size=self.megatron_config.context_parallel_size ) def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None): diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py index 2e05fbfa..6fc6ed52 100644 --- a/mindspeed_rl/workers/reward_woker.py +++ b/mindspeed_rl/workers/reward_woker.py @@ -73,15 +73,15 @@ class RewardWorkerBase(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, + temperature=self.generate_config.sampling_config["temperature"], use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, + max_packing_token_size=self.rl_config.ref_max_packing_token_size, + dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size, use_remove_padding=self.rl_config.use_remove_padding, set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - temperature=self.generate_config.sampling_config["temperature"] + context_parallel_size=self.megatron_config.context_parallel_size ) def init_transfer_dock(self, td, sampling_transfer_dock=None): diff --git a/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml b/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml index 660b31cf..5e47753a 100644 --- a/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml +++ b/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml @@ -61,7 +61,9 @@ rl_config: blocking: true use_dp_batch_balance: true use_dynamic_bsz: true - max_packing_token_size: 8192 + ref_max_packing_token_size: 8192 + actor_max_packing_token_size: 8192 + update_max_packing_token_size: 8192 actor_forward_micro_batch_size: 8 ref_forward_micro_batch_size: 8 use_remove_padding: true -- Gitee From d852b83c525576a7d3f40d80bad62276630ddbb6 Mon Sep 17 00:00:00 2001 From: nurxat Date: Tue, 19 Aug 2025 17:23:33 +0800 Subject: [PATCH 2/2] dbbug --- mindspeed_rl/config_cls/rl_config.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py index 9cc15f2b..86b909d2 100644 --- a/mindspeed_rl/config_cls/rl_config.py +++ b/mindspeed_rl/config_cls/rl_config.py @@ -129,12 +129,14 @@ class RLConfig(BaseConfig): self.mini_batch_size = 1 self.use_dynamic_bsz = False - self.ref_max_packing_token_size = 4096 - self.actor_max_packing_token_size = 4096 - self.update_max_packing_token_size = 4096 - self.ref_dynamic_max_batch_size = None - self.actor_dynamic_max_batch_size = None - self.update_dynamic_max_batch_size = None + self.max_packing_token_size = 4096 + self.ref_max_packing_token_size = self.max_packing_token_size + self.actor_max_packing_token_size = self.max_packing_token_size + self.update_max_packing_token_size = self.max_packing_token_size + self.dynamic_max_batch_size = None + self.ref_dynamic_max_batch_size = self.dynamic_max_batch_size + self.actor_dynamic_max_batch_size = self.dynamic_max_batch_size + self.update_dynamic_max_batch_size = self.dynamic_max_batch_size self.log_max_throughput = True -- Gitee