diff --git a/configs/grpo_qwen25_32b_A3.yaml b/configs/grpo_qwen25_32b_A3.yaml
index 7a94c7038e2b7d5124bc32f8dc571a7b2c099a36..48612abe1bda5e9cbbe70b65e003f102f615d57c 100644
--- a/configs/grpo_qwen25_32b_A3.yaml
+++ b/configs/grpo_qwen25_32b_A3.yaml
@@ -61,7 +61,7 @@ rl_config:
   ref_forward_micro_batch_size: 16
   reward_dispatch_size: 64
   adv_dispatch_size: 64
-  actor_update_dispatch_size: 16
+  actor_update_dispatch_size: 1024
   use_integrated_worker: true
   gamma: 1.0
   lam: 0.95
diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py
index a43dbca952a53f648812c57604590e7d487e1c52..4a8e8826cba8139a1ac9a3980c2a7a68efc573e5 100644
--- a/mindspeed_rl/config_cls/validate_config.py
+++ b/mindspeed_rl/config_cls/validate_config.py
@@ -350,6 +350,16 @@ def validate_rl_args(
         rl_config.critic_update_dispatch_size,
         "Critic Update")
 
+    # If a custom actor_update_dispatch_size is given, check that it satisfies the on_policy/off_policy requirements
+    if rl_config.actor_update_dispatch_size:
+        if rl_config.actor_update_dispatch_size < rl_config.mini_batch_size / actor_data_parallel_size:
+            raise ValueError(
+                f"actor_update_dispatch_size={rl_config.actor_update_dispatch_size} "
+                f"must be >= mini_batch_size/actor_data_parallel_size "
+                f"({rl_config.mini_batch_size}/{actor_data_parallel_size}="
+                f"{int(rl_config.mini_batch_size / actor_data_parallel_size)})"
+            )
+
     if rl_config.filter_groups_enable:
         # With DAPO dynamic sampling enabled, the global batch size for the update equals filter_groups_train_batch_size
         rl_config.actor_update_dispatch_size = (
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 342846d4bcfe300b21c0bb01be8cd7a5bc5ad12a..d955350020316ebf80872b47cea335ee83e17dd3 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -211,20 +211,10 @@ class ActorHybridWorkerBase(BaseWorker):
             experience_columns = ['responses', 'advantages', 'old_log_prob', 'input_ids', 'response_length', 'prompt_length']
         else:
             experience_columns = ['responses', 'advantages', 'old_log_prob', 'ref_log_prob', 'input_ids', 'response_length', 'prompt_length']
-
         if is_multimodal():
             experience_columns.extend(['attention_mask', 'position_ids'])
-            experience_count = self.rl_config.actor_update_dispatch_size
-        else:
-            experience_count = (
-                self.megatron_config.global_batch_size // self.parallel_state.get_data_parallel_world_size()
-            )
-
-        if self.rl_config.filter_groups_enable:
-            experience_count = (
-                self.rl_config.filter_groups_train_batch_size * self.rl_config.n_samples_per_prompt //
-                self.parallel_state.get_data_parallel_world_size()
-            )
+
+        experience_count = self.rl_config.actor_update_dispatch_size
 
         if skip_actor_log_prob:
             experience_columns.remove('old_log_prob')
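For reference, a minimal standalone sketch of the constraint the new validation enforces. The helper name check_actor_update_dispatch_size and the example values (mini_batch_size=512, actor_data_parallel_size=4) are illustrative assumptions, not taken from the patch; only the inequality and the error message format come from the diff above.

def check_actor_update_dispatch_size(actor_update_dispatch_size, mini_batch_size, actor_data_parallel_size):
    # The patch requires the dispatch size to be at least the per-rank share
    # of a mini batch, i.e. mini_batch_size / actor_data_parallel_size.
    min_size = mini_batch_size / actor_data_parallel_size
    if actor_update_dispatch_size < min_size:
        raise ValueError(
            f"actor_update_dispatch_size={actor_update_dispatch_size} "
            f"must be >= mini_batch_size/actor_data_parallel_size "
            f"({mini_batch_size}/{actor_data_parallel_size}={int(min_size)})"
        )

check_actor_update_dispatch_size(1024, 512, 4)  # the new config value passes: 1024 >= 128

try:
    check_actor_update_dispatch_size(16, 512, 4)  # the old config value of 16
except ValueError as err:
    print(err)  # rejected under these assumed values: 16 < 512/4 = 128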