diff --git a/docs/features/remove_padding.md b/docs/features/remove_padding.md index 089a8e088d3edc6078ace950d8c8d704b20e5c80..514a43461706698edb44da65d8ada06d72fb8da8 100644 --- a/docs/features/remove_padding.md +++ b/docs/features/remove_padding.md @@ -68,8 +68,8 @@ rl_config: ## 参数说明: -`max_packing_token_size` 是动态批大小(Dynamic Batch Size)机制中的核心参数,用于限制每个拼接后的 micro batch 中 token 的总数,防止因拼接过多序列而导致显存溢出(OOM)。 -`dynamic_max_batch_size` 用于限制最大的 micro batch,防止在长序列训练场景下,有多个短序列放入同一批次导致 micro batch size 过大,进而导致 OOM。 +`ref_max_packing_token_size`, `actor_max_packing_token_size`, `update_max_packing_token_size` 是动态批大小(Dynamic Batch Size)机制中的核心参数,用于限制每个拼接后的 micro batch 中 token 的总数,防止因拼接过多序列而导致显存溢出(OOM)。 +`ref_dynamic_max_batch_size`, `actor_dynamic_max_batch_size`, `update_dynamic_max_batch_size` 用于控制 Dynamic Batch Size 分箱之后每个批次中的最大序列条数(即 micro batch size),防止在长序列训练场景下,有多个短序列放入同一批次导致 micro batch size 过大,进而导致 OOM。 **使用限制**:每条样本的 token 长度必须满足: ```text @@ -80,7 +80,7 @@ prompt_length[i] + response_length[i] <= max_packing_token_size ```text max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling_config.max_tokens) * 2 ``` -`dynamic_max_batch_size` 是可选参数。如果长序列训练过程发生 OOM,且发生在计算得出 logits 之后,可以通过设置或减小该值减少显存占用,建议最小设置为2,若设置为1,则 Dynamic Batch Size 无意义。 +`*_dynamic_max_batch_size` 是可选参数。如果长序列训练过程发生 OOM,且发生在计算得出 logits 之后,可以通过设置或减小该值来减少显存占用,建议最小设置为2,若设置为1,则 Dynamic Batch Size 无意义。 二者可以根据实际需求调整。 @@ -93,8 +93,12 @@ max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling ```yaml rl_config: use_dynamic_bsz: true - max_packing_token_size: 8192 - dynamic_max_batch_size: 8 # 可选参数 + ref_max_packing_token_size: 8192 + ref_dynamic_max_batch_size: 8 # 可选参数 + actor_max_packing_token_size: 8192 + actor_dynamic_max_batch_size: 8 # 可选参数 + update_max_packing_token_size: 8192 + update_dynamic_max_batch_size: 8 # 可选参数 ``` # 📦 数据并行负载均衡(DP Batch Balance)特性 diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py index 
8c2637666d5dcc7e7fdc64b265d9d14371aee6e9..86b909d270eb2a0da08828a81f8f2276121dfc33 100644 --- a/mindspeed_rl/config_cls/rl_config.py +++ b/mindspeed_rl/config_cls/rl_config.py @@ -130,8 +130,15 @@ class RLConfig(BaseConfig): self.use_dynamic_bsz = False self.max_packing_token_size = 4096 - self.log_max_throughput = True + self.ref_max_packing_token_size = self.max_packing_token_size + self.actor_max_packing_token_size = self.max_packing_token_size + self.update_max_packing_token_size = self.max_packing_token_size self.dynamic_max_batch_size = None + self.ref_dynamic_max_batch_size = self.dynamic_max_batch_size + self.actor_dynamic_max_batch_size = self.dynamic_max_batch_size + self.update_dynamic_max_batch_size = self.dynamic_max_batch_size + + self.log_max_throughput = True # token level loss self.token_level_loss = True diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py index f0da389375065baabfe5ef45a1e798ed7bafadd9..b1aa01c4fd455737130e067d17c37d2c6461ca49 100644 --- a/mindspeed_rl/models/base/base_training_engine.py +++ b/mindspeed_rl/models/base/base_training_engine.py @@ -62,16 +62,8 @@ class BaseTrainingEngine(ABC): temperature: float = 1.0, role: str = None, micro_batch_size: int = 1, - use_dynamic_bsz: bool = False, - max_packing_token_size: bool = 4096, - dynamic_max_batch_size: int = None, - use_remove_padding: bool = False, - set_actual_seq_len: Callable = None, - get_actual_seq_len: Callable = None, - set_position_ids: Callable = None, forward_backward_func: Callable = None, entropy_coeff: float = 0.0, - context_parallel_size: int = 1, kl_penalty: str = "low_var_kl", token_level_loss: bool = False, clip_higher_enable: bool = False, @@ -81,13 +73,6 @@ class BaseTrainingEngine(ABC): **kwargs): self.forward_backward_func = forward_backward_func self.micro_batch_size = micro_batch_size - self.use_dynamic_bsz = use_dynamic_bsz - self.max_packing_token_size = max_packing_token_size - 
self.dynamic_max_batch_size = dynamic_max_batch_size - self.use_remove_padding = use_remove_padding - self.set_actual_seq_len = set_actual_seq_len - self.get_actual_seq_len = get_actual_seq_len - self.set_position_ids = set_position_ids self.model = model self.megatron_config = megatron_config self.optimizer = optimizer @@ -102,7 +87,6 @@ class BaseTrainingEngine(ABC): self.kl_penalty = kl_penalty self.clip_ratio = clip_ratio self.entropy_coeff = entropy_coeff - self.context_parallel_size = context_parallel_size self.temperature = temperature self.token_level_loss = token_level_loss self.clip_higher_enable = clip_higher_enable @@ -111,7 +95,21 @@ class BaseTrainingEngine(ABC): self.cliprange_value = cliprange_value self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role) self.kwargs = kwargs - + + self.use_remove_padding = kwargs.get('use_remove_padding', False) + self.use_dynamic_bsz = kwargs.get('use_dynamic_bsz', False) + self.max_packing_token_size = kwargs.get('ref_max_packing_token_size', None) + self.dynamic_max_batch_size = kwargs.get('ref_dynamic_max_batch_size', None) + if self.max_packing_token_size is None: + self.max_packing_token_size = {'actor': kwargs.get('actor_max_packing_token_size', None), + 'update': kwargs.get('update_max_packing_token_size', None)} + self.dynamic_max_batch_size = {'actor': kwargs.get('actor_dynamic_max_batch_size', None), + 'update': kwargs.get('update_dynamic_max_batch_size', None)} + self.context_parallel_size = kwargs.get('context_parallel_size', 1) + self.set_actual_seq_len = kwargs.get('set_actual_seq_len', None) + self.get_actual_seq_len = kwargs.get('get_actual_seq_len', None) + self.set_position_ids = kwargs.get('set_position_ids', None) + @staticmethod def _split_batches(batch: Dict, batch_size: int, shuffle_mini_batch: bool, dim: int = 0, keep_list: bool = False) -> List[Dict]: batches = [] @@ -149,7 +147,13 @@ class BaseTrainingEngine(ABC): def _forward_backward_batch(self, batch: 
Dict[str, torch.Tensor], forward_only: bool = False): if self.use_dynamic_bsz: - batches, indices = self._split_batches_with_dynamic_bsz(batch, self.max_packing_token_size, self.dynamic_max_batch_size) + if isinstance(self.max_packing_token_size, dict): + max_packing_token_size = self.max_packing_token_size['actor'] if forward_only else self.max_packing_token_size['update'] # actor forward or update + dynamic_max_batch_size = self.dynamic_max_batch_size['actor'] if forward_only else self.dynamic_max_batch_size['update'] + else: + max_packing_token_size = self.max_packing_token_size # reference forward + dynamic_max_batch_size = self.dynamic_max_batch_size + batches, indices = self._split_batches_with_dynamic_bsz(batch, max_packing_token_size, dynamic_max_batch_size) else: batches = self._split_batches(batch, batch_size=self.micro_batch_size, shuffle_mini_batch=self.shuffle_mini_batch) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index d959db448a77d31fbf2f0de1d01a522690db0ac9..9218462d9f9484138adeab384422d12fb9ebb894 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -130,21 +130,24 @@ class ActorHybridWorkerBase(BaseWorker): forward_backward_func=self.forward_backward_func, clip_ratio=self.rl_config.clip_ratio, micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, - use_remove_padding=self.rl_config.use_remove_padding, - set_actual_seq_len=self.set_actual_seq_len, - get_actual_seq_len=self.get_actual_seq_len, - set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, entropy_coeff=self.rl_config.entropy_coeff, kl_penalty=self.rl_config.kl_penalty, temperature=self.generate_config.sampling_config["temperature"], 
token_level_loss=self.rl_config.token_level_loss, clip_higher_enable=self.rl_config.clip_higher_enable, clip_ratio_low=self.rl_config.clip_ratio_low, - clip_ratio_high=self.rl_config.clip_ratio_high + clip_ratio_high=self.rl_config.clip_ratio_high, + + use_remove_padding=self.rl_config.use_remove_padding, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, + actor_max_packing_token_size=self.rl_config.actor_max_packing_token_size, + update_max_packing_token_size=self.rl_config.update_max_packing_token_size, + actor_dynamic_max_batch_size=self.rl_config.actor_dynamic_max_batch_size, + update_dynamic_max_batch_size=self.rl_config.update_dynamic_max_batch_size, + set_actual_seq_len=self.set_actual_seq_len, + get_actual_seq_len=self.get_actual_seq_len, + set_position_ids=self.set_position_ids, + context_parallel_size=self.megatron_config.context_parallel_size ) self.empty_cache() self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role) diff --git a/mindspeed_rl/workers/critic_worker.py b/mindspeed_rl/workers/critic_worker.py index 188d671a859388e873e9bda1b88b6d40e6d668c4..b31061f38a39a493faddc7168136cc9ef6af717b 100644 --- a/mindspeed_rl/workers/critic_worker.py +++ b/mindspeed_rl/workers/critic_worker.py @@ -84,7 +84,6 @@ class CriticWorkerBase(BaseWorker): self.critic_offloader.offload_grad() self.critic_offloader.offload_param() - megatron_module = self.get_megatron_module() self.critic = Critic( self.model, megatron_config=self.megatron_config, @@ -99,6 +98,9 @@ class CriticWorkerBase(BaseWorker): forward_backward_func=self.forward_backward_func, clip_ratio=self.rl_config.clip_ratio, micro_batch_size=self.megatron_config.micro_batch_size, + entropy_coeff=self.rl_config.entropy_coeff, + cliprange_value=self.rl_config.cliprange_value, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, max_packing_token_size=self.rl_config.max_packing_token_size, dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, @@ -106,9 +108,7 @@ class 
CriticWorkerBase(BaseWorker): set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - entropy_coeff=self.rl_config.entropy_coeff, - cliprange_value=self.rl_config.cliprange_value + context_parallel_size=self.megatron_config.context_parallel_size ) self.empty_cache() self.critic_profiler = profiler_start(self.profiler_config, self.profiler_config.role) diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py index 91d59f8c2ae23504c116efe6b5f07f5701a432a8..2bc9676f19c6e6c5be43b1b22721ce03d6cee860 100644 --- a/mindspeed_rl/workers/integrated_worker.py +++ b/mindspeed_rl/workers/integrated_worker.py @@ -103,15 +103,16 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, + temperature=self.generate_config.sampling_config["temperature"], + use_remove_padding=self.rl_config.use_remove_padding, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, + ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size, + ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size, set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - temperature=self.generate_config.sampling_config["temperature"] + context_parallel_size=self.megatron_config.context_parallel_size ) MsProbe.config_init(self.msprobe_config) diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 
1781dba70eae125da24b676bdd5dbf181f77d5eb..6f47edf19556e33f073237c28a46fc5ed981bee9 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -82,15 +82,16 @@ class ReferenceWorkerBase(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, - use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, + temperature=self.generate_config.sampling_config["temperature"], + use_remove_padding=self.rl_config.use_remove_padding, + use_dynamic_bsz=self.rl_config.use_dynamic_bsz, + ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size, + ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size, set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - temperature=self.generate_config.sampling_config["temperature"] + context_parallel_size=self.megatron_config.context_parallel_size ) def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None): diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py index 01dcb591c8b9abc5ff8b5047bf3d326ce0290377..78c710b9483282a385efc33c80f01d598a9a31b3 100644 --- a/mindspeed_rl/workers/reward_woker.py +++ b/mindspeed_rl/workers/reward_woker.py @@ -73,15 +73,15 @@ class RewardWorkerBase(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, + temperature=self.generate_config.sampling_config["temperature"], use_dynamic_bsz=self.rl_config.use_dynamic_bsz, - max_packing_token_size=self.rl_config.max_packing_token_size, - dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size, + 
max_packing_token_size=self.rl_config.ref_max_packing_token_size, + dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size, use_remove_padding=self.rl_config.use_remove_padding, set_actual_seq_len=self.set_actual_seq_len, get_actual_seq_len=self.get_actual_seq_len, set_position_ids=self.set_position_ids, - context_parallel_size=self.megatron_config.context_parallel_size, - temperature=self.generate_config.sampling_config["temperature"] + context_parallel_size=self.megatron_config.context_parallel_size ) def init_transfer_dock(self, td, sampling_transfer_dock=None): diff --git a/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml b/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml index 48bb4f7151f6c7c3d09fb6415dc245716c87c89d..40af50c823ec955d4bbc08bd43176ee94013453f 100644 --- a/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml +++ b/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml @@ -62,7 +62,9 @@ rl_config: blocking: true use_dp_batch_balance: true use_dynamic_bsz: true - max_packing_token_size: 8192 + ref_max_packing_token_size: 8192 + actor_max_packing_token_size: 8192 + update_max_packing_token_size: 8192 actor_forward_micro_batch_size: 8 ref_forward_micro_batch_size: 8 use_remove_padding: true