From 3caacff42a711081655eca62d3251879fc024c7e Mon Sep 17 00:00:00 2001 From: dangzzz <814207734@qq.com> Date: Thu, 4 Sep 2025 17:23:27 +0800 Subject: [PATCH 1/3] [bug_fix]actor_update_dispatch_size --- configs/grpo_qwen25_32b_A3.yaml | 2 +- mindspeed_rl/config_cls/validate_config.py | 9 +++++++++ mindspeed_rl/workers/actor_hybrid_worker.py | 9 ++------- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/configs/grpo_qwen25_32b_A3.yaml b/configs/grpo_qwen25_32b_A3.yaml index 7a94c703..48612abe 100644 --- a/configs/grpo_qwen25_32b_A3.yaml +++ b/configs/grpo_qwen25_32b_A3.yaml @@ -61,7 +61,7 @@ rl_config: ref_forward_micro_batch_size: 16 reward_dispatch_size: 64 adv_dispatch_size: 64 - actor_update_dispatch_size: 16 + actor_update_dispatch_size: 1024 use_integrated_worker: true gamma: 1.0 lam: 0.95 diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py index a43dbca9..e2056096 100644 --- a/mindspeed_rl/config_cls/validate_config.py +++ b/mindspeed_rl/config_cls/validate_config.py @@ -367,6 +367,15 @@ def validate_rl_args( _validate_experience_ratio(actor_config.global_batch_size, rl_config.actor_update_dispatch_size, "Actor Update") + + # 检查 actor_update_dispatch_size 是否符合 on_policy/off_policy 策略要求 + if rl_config.actor_update_dispatch_size < rl_config.mini_batch_size / actor_data_parallel_size: + raise ValueError( + f"actor_update_dispatch_size={rl_config.actor_update_dispatch_size} " + f"must be >= mini_batch_size/actor_data_parallel_size " + f"({rl_config.mini_batch_size}/{actor_data_parallel_size}=" + f"{int(rl_config.mini_batch_size/actor_data_parallel_size)})" + ) # 检查验证器参数匹配 if len(rl_config.verifier_function) != len(rl_config.verifier_weight): diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 342846d4..fc94760f 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -211,15 +211,10 @@ class ActorHybridWorkerBase(BaseWorker): experience_columns = ['responses', 'advantages', 'old_log_prob', 'input_ids', 'response_length', 'prompt_length'] else: experience_columns = ['responses', 'advantages', 'old_log_prob', 'ref_log_prob', 'input_ids', 'response_length', 'prompt_length'] - if is_multimodal(): experience_columns.extend(['attention_mask', 'position_ids']) - experience_count = self.rl_config.actor_update_dispatch_size - else: - experience_count = ( - self.megatron_config.global_batch_size // self.parallel_state.get_data_parallel_world_size() - ) - + + experience_count = self.rl_config.actor_update_dispatch_size if self.rl_config.filter_groups_enable: experience_count = ( self.rl_config.filter_groups_train_batch_size * self.rl_config.n_samples_per_prompt // -- Gitee From 15c01c23b3ac01f3e85618c30a903a082658ddb7 Mon Sep 17 00:00:00 2001 From: dangzzz <814207734@qq.com> Date: Thu, 4 Sep 2025 17:35:15 +0800 Subject: [PATCH 2/3] [bug_fix]actor_update_dispatch_size --- mindspeed_rl/workers/actor_hybrid_worker.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index fc94760f..d9553500 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -215,11 +215,6 @@ class ActorHybridWorkerBase(BaseWorker): experience_columns.extend(['attention_mask', 'position_ids']) experience_count = self.rl_config.actor_update_dispatch_size - if self.rl_config.filter_groups_enable: - experience_count = ( - self.rl_config.filter_groups_train_batch_size * self.rl_config.n_samples_per_prompt // - self.parallel_state.get_data_parallel_world_size() - ) if skip_actor_log_prob: experience_columns.remove('old_log_prob') -- Gitee From ade1313f71564f60c0182687350244dba70e76e9 Mon Sep 17 00:00:00 2001 From: dangzzz <814207734@qq.com> Date: Mon, 8 Sep 2025 14:52:13 +0800 Subject: [PATCH 3/3] [bug_fix]actor_update_dispatch_size --- mindspeed_rl/config_cls/validate_config.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py index e2056096..4a8e8826 100644 --- a/mindspeed_rl/config_cls/validate_config.py +++ b/mindspeed_rl/config_cls/validate_config.py @@ -350,6 +350,16 @@ def validate_rl_args( rl_config.critic_update_dispatch_size, "Critic Update") + # 若指定了自定义的actor_update_dispatch_size,检查 actor_update_dispatch_size 是否符合 on_policy/off_policy 策略要求 + if rl_config.actor_update_dispatch_size: + if rl_config.actor_update_dispatch_size < rl_config.mini_batch_size / actor_data_parallel_size: + raise ValueError( + f"actor_update_dispatch_size={rl_config.actor_update_dispatch_size} " + f"must be >= mini_batch_size/actor_data_parallel_size " + f"({rl_config.mini_batch_size}/{actor_data_parallel_size}=" + f"{int(rl_config.mini_batch_size/actor_data_parallel_size)})" + ) + if rl_config.filter_groups_enable: # 若开启dapo动态采样,update的gbs=filter_groups_train_batch_size rl_config.actor_update_dispatch_size = ( @@ -367,15 +377,6 @@ def validate_rl_args( _validate_experience_ratio(actor_config.global_batch_size, rl_config.actor_update_dispatch_size, "Actor Update") - - # 检查 actor_update_dispatch_size 是否符合 on_policy/off_policy 策略要求 - if rl_config.actor_update_dispatch_size < rl_config.mini_batch_size / actor_data_parallel_size: - raise ValueError( - f"actor_update_dispatch_size={rl_config.actor_update_dispatch_size} " - f"must be >= mini_batch_size/actor_data_parallel_size " - f"({rl_config.mini_batch_size}/{actor_data_parallel_size}=" - f"{int(rl_config.mini_batch_size/actor_data_parallel_size)})" - ) # 检查验证器参数匹配 if len(rl_config.verifier_function) != len(rl_config.verifier_weight): -- Gitee