From 9f0baaf270390f99de4871d15a5e0cd9aaf9a5e1 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Thu, 5 Jun 2025 15:06:56 +0800 Subject: [PATCH 01/12] . MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add sync_ref support in the reference-split (non-integrated worker) mode: periodically sync the reference model from the actor weights via mixup. Also add separate profiler configs for the actor and reference workers, and add ref-update system tests. --- .gitignore | 3 + cli/train_grpo.py | 20 +- mindspeed_rl/config_cls/rl_config.py | 23 +- mindspeed_rl/config_cls/validate_config.py | 26 +- mindspeed_rl/models/rollout/vllm_engine.py | 119 +++---- mindspeed_rl/trainer/base.py | 16 +- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 61 +++- mindspeed_rl/workers/actor_hybrid_worker.py | 28 +- mindspeed_rl/workers/base_worker.py | 7 + mindspeed_rl/workers/reference_woker.py | 122 ++++---- mindspeed_rl/workers/scheduler/launcher.py | 6 + ...test_grpo_trainer_qwen25_7b_ref_split.yaml | 184 +++++++++++ .../test_full_grpo_qwen25_7b_ref_update.sh | 34 ++ .../test_module_entry_ref_update.sh | 70 +++++ tests/st/ref_update/test_ref_update.py | 291 ++++++++++++++++++ 15 files changed, 860 insertions(+), 150 deletions(-) create mode 100644 tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml create mode 100644 tests/st/ref_update/test_full_grpo_qwen25_7b_ref_update.sh create mode 100644 tests/st/ref_update/test_module_entry_ref_update.sh create mode 100644 tests/st/ref_update/test_ref_update.py diff --git a/.gitignore b/.gitignore index f5a996cd..7114695a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,7 @@ # Byte-compiled / optimized / DLL files +outputs +tmp +runs __pycache__/ *.py[cod] *$py.class diff --git a/cli/train_grpo.py b/cli/train_grpo.py index fb687e17..9958b8ae 100644 --- a/cli/train_grpo.py +++ b/cli/train_grpo.py @@ -46,10 +46,10 @@ def train(config): MsProbe.config_init(msprobe_config) MsProbe.save_configs({ - 'actor': eval(str(actor_config.dict())), - 'ref': eval(str(ref_config.dict())), - 'reward': eval(str(reward_config.dict())), - 'rl': eval(str(rl_config.dict())), + 'actor': eval(str(actor_config.dict())), + 'ref': eval(str(ref_config.dict())), + 'reward': eval(str(reward_config.dict())), + 'rl': eval(str(rl_config.dict())), 'generate': eval(str(generate_config.dict())) }) @@ -87,6 +87,8 @@ def train(config): rl_config=rl_config, generate_config=generate_config, model_provider=gpt_model_provider, + profiler_config=profiler_config["actor"], + msprobe_config=msprobe_config, tokenizer=tokenizer, initialize_func=initialize_megatron, get_megatron_module=get_megatron_module, @@ -100,6 +102,8 @@ def train(config): rl_config=rl_config, generate_config=generate_config, model_provider=gpt_model_provider, + profiler_config=profiler_config["reference"], + msprobe_config=msprobe_config, tokenizer=tokenizer, initialize_func=initialize_megatron, get_megatron_module=get_megatron_module, @@ -227,6 +231,14 @@ def parse_training_config(config: Dict): config.get("profiler_config", {}).get("integrated", {}), role="integrated" ), + "actor": ProfilerConfig( + config.get("profiler_config", {}).get("actor", {}), + role="actor" + ), + "reference": ProfilerConfig( + config.get("profiler_config", {}).get("reference", {}), + role="reference" + ), }) actor_config.max_prompt_length = rl_config.max_prompt_length diff --git a/mindspeed_rl/config_cls/rl_config.py
b/mindspeed_rl/config_cls/rl_config.py index 44e65d3d..f58f55fd 100644 --- a/mindspeed_rl/config_cls/rl_config.py +++ b/mindspeed_rl/config_cls/rl_config.py @@ -2,7 +2,7 @@ from mindspeed_rl.config_cls.base_config import BaseConfig class RLConfig(BaseConfig): - ''' + """ RL configuration class. Initialize model configuration from the provided config dictionary. All instance attributes are initialized using the dictionary keys. @@ -38,6 +38,10 @@ class RLConfig(BaseConfig): adv_dispatch_size: experience count every forward step for advantages (default: same as experience_count_all) actor_update_dispatch_size: experience count every forward step for actor update (default: same as experience_count_all) + sync_ref_model: Whether to periodically update the reference model from the actor (default: False) + ref_model_sync_steps: Number of training steps between reference model updates (default: 128) + ref_model_mixup_alpha: Mixup coefficient for the reference weight update `π_ref = α * π_θ + (1 - α) * π_ref_prev` (default: α = 0.9) + shuffle_mini_batch: Whether to shuffle minibatch (default: False) n_samples_per_prompt: Number of samples per prompt (default: 1) enable_sharding_validate: Whether to enable sharding validation (default: False) @@ -51,10 +55,10 @@ class RLConfig(BaseConfig): num_cpus_for_local_task: Number of CPUs for local ray task (default: 1) num_cpus_for_placement_group: Number of CPUs for ray worker placement group # Default values can still be defined if no config is provided - ''' + """ def __init__(self, config_dict): - self.runtime_env_path = 'configs/envs/runtime_env.yaml' + self.runtime_env_path = "configs/envs/runtime_env.yaml" self.rule_reward = True self.beta = 0.1 self.actor_resource = None @@ -69,11 +73,11 @@ class RLConfig(BaseConfig): self.lam = 0.95 self.advantage_whiten = True self.kl_penalty = "low_var_kl" - self.kl_ctrl_type = 'fixed' + self.kl_ctrl_type = "fixed" self.init_kl_coef = 0.01 self.kl_horizon = 1000 self.kl_target = 100.0 - self.adv_estimator = 'group_norm' + self.adv_estimator = "group_norm" self.verifier_function = ["base_acc", ] self.verifier_weight = [1.0, ] self.verifier_parallel = 1 @@ -104,7 +108,12 @@ class RLConfig(BaseConfig): self.adv_dispatch_size = None self.actor_update_dispatch_size = None + self.sync_ref_model = False + self.ref_model_sync_steps = 128 + self.ref_model_mixup_alpha = 0.9 + self.ref_save = None + self.update(config_dict) - self.n_samples_per_prompt = config_dict.get('n_samples_per_prompt', 1) - self.mini_batch_size = config_dict.get('mini_batch_size', 1) * self.n_samples_per_prompt + self.n_samples_per_prompt = config_dict.get("n_samples_per_prompt", 1) + self.mini_batch_size = config_dict.get("mini_batch_size", 1) * self.n_samples_per_prompt diff --git a/mindspeed_rl/config_cls/validate_config.py b/mindspeed_rl/config_cls/validate_config.py index e037bcd8..6c4c450c 100644 --- a/mindspeed_rl/config_cls/validate_config.py +++ b/mindspeed_rl/config_cls/validate_config.py @@ -3,6 +3,7 @@ import os from mindspeed_rl.config_cls.rl_config import RLConfig from mindspeed_rl.config_cls.megatron_config import MegatronConfig from mindspeed_rl.config_cls.generate_config import GenerateConfig +from mindspeed_rl.models import actor def validate_rl_args( @@ -61,7 +62,7 @@ def validate_rl_args( reward_config.pipeline_model_parallel_size, reward_config.context_parallel_size, "Reward") - + # 校验批次大小与微批次关系 def _validate_batch_ratio(global_batch, micro_batch, n_samples, component): if (global_batch * n_samples) % micro_batch != 0: raise ValueError( @@ -92,7 +93,7 @@ def validate_rl_args( raise ValueError( f"{component}
global_batch_size {global_batch_size} " f"must be divisible by data_parallel_size {data_parallel}") - + if (global_batch_size // data_parallel * n_samples) % micro_batch_size != 0: raise ValueError( f"{component} global_batch_size {actor_config.global_batch_size} " f"must be divisible by micro_batch_size {micro_batch_size} " @@ -211,7 +212,28 @@ def validate_rl_args( _validate_experience_ratio(reward_config.global_batch_size, rl_config.actor_update_dispatch_size, "Actor Update") + if rl_config.sync_ref_model: + if rl_config.ref_model_sync_steps <= 0: + raise ValueError( + f"Reference update steps {rl_config.ref_model_sync_steps} " + f"must be greater than 0." + ) + if rl_config.ref_model_mixup_alpha <= 0 or rl_config.ref_model_mixup_alpha > 1: + raise ValueError( + f"Reference update mixup ratio {rl_config.ref_model_mixup_alpha} " + f"must satisfy 0 < alpha <= 1" + ) + if not rl_config.use_integrated_worker: + check_tp = actor_config.tensor_model_parallel_size == ref_config.tensor_model_parallel_size + check_pp = actor_config.pipeline_model_parallel_size == ref_config.pipeline_model_parallel_size + str_tp = "" if check_tp else f"Actor TP({actor_config.tensor_model_parallel_size}) \ + mismatch Reference TP({ref_config.tensor_model_parallel_size})" + str_pp = "" if check_pp else f"Actor PP({actor_config.pipeline_model_parallel_size}) \ + mismatch Reference PP({ref_config.pipeline_model_parallel_size})" + if not (check_tp and check_pp): + raise ValueError(str_tp + str_pp) + # 检查验证器参数匹配 if len(rl_config.verifier_function) != len(rl_config.verifier_weight): raise ValueError( diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 649f3698..492996db 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -23,7 +23,7 @@ def dummy_compile(*compile_args, **compile_kwargs): return wrapper return decorate - + torch.jit.script = dummy_compile torch.compile = dummy_compile @@ -32,11 +32,7 @@ from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.models.base.base_inference_engine import BaseInferEngine from mindspeed_rl.config_cls.megatron_config import MegatronConfig from mindspeed_rl.models.rollout.vllm_adapter.vllm_parallel_state import initialize_parallel_state -from mindspeed_rl.models.rollout.vllm_adapter.megatron_weight_loaders import ( - load_megatron_weights, - update_megatron_weight_loader, - InferParallelConfig -) +from mindspeed_rl.models.rollout.vllm_adapter.megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader, InferParallelConfig from mindspeed_rl.utils import get_tokenizer logger = Loggers("vllm_engine") @@ -44,27 +40,27 @@ logger = Loggers("vllm_engine") class VLLMInferEngine(BaseInferEngine): def __init__( - self, - tokenizer_name_or_path: str, - train_tensor_parallel_size: int, - train_pipeline_parallel_size: int, - train_expert_parallel_size: int, - infer_tensor_parallel_size: int, - infer_pipeline_parallel_size: int, - infer_expert_parallel_size: int, - megatron_config: MegatronConfig, - sampling_config: dict, - prompt_type: str = None, - prompt_type_path: str = None, - enable_prefix_caching: bool = False, - num_scheduler_steps: int = 1, - max_num_seqs: int = 1, - max_model_len: int = 2048, - dtype: str = "bfloat16", - gpu_memory_utilization: float = 0.5, - trust_remote_code: bool = True, - load_format: str = "megatron", - **kwargs + self, + tokenizer_name_or_path: str, + train_tensor_parallel_size: int, + train_pipeline_parallel_size: int, + train_expert_parallel_size: int, + infer_tensor_parallel_size:
int, + infer_pipeline_parallel_size: int, + infer_expert_parallel_size: int, + megatron_config: MegatronConfig, + sampling_config: dict, + prompt_type: str = None, + prompt_type_path: str = None, + enable_prefix_caching: bool = False, + num_scheduler_steps: int = 1, + max_num_seqs: int = 1, + max_model_len: int = 2048, + dtype: str = "bfloat16", + gpu_memory_utilization: float = 0.5, + trust_remote_code: bool = True, + load_format: str = "megatron", + **kwargs, ): """ Initialize the VLLM inference engine. @@ -102,37 +98,31 @@ class VLLMInferEngine(BaseInferEngine): max_model_len=max_model_len, dtype=dtype, gpu_memory_utilization=gpu_memory_utilization, - trust_remote_code=trust_remote_code + trust_remote_code=trust_remote_code, ) # Additional initialization logic for VLLMInferEngine # Initialize sampling parameters from SamplingConfig self.sampling_config = sampling_config try: self.sampling_params = SamplingParams( - n=sampling_config.get('num_completions', 1), - logprobs=sampling_config.get('logprobs', 1), - max_tokens=sampling_config.get('max_tokens', 128), - best_of=sampling_config.get('best_of', 2), - top_p=sampling_config.get('top_p', 1.0), - top_k=sampling_config.get('top_k', 50), - min_p=sampling_config.get('min_p', 0.0), - temperature=sampling_config.get('temperature', 0.2), - detokenize=sampling_config.get('detokenize', False), - seed=sampling_config.get('seed', None) + n=sampling_config.get("num_completions", 1), + logprobs=sampling_config.get("logprobs", 1), + max_tokens=sampling_config.get("max_tokens", 128), + best_of=sampling_config.get("best_of", 2), + top_p=sampling_config.get("top_p", 1.0), + top_k=sampling_config.get("top_k", 50), + min_p=sampling_config.get("min_p", 0.0), + temperature=sampling_config.get("temperature", 0.2), + detokenize=sampling_config.get("detokenize", False), + seed=sampling_config.get("seed", None), ) except Exception as e: raise ValueError(f"Error creating SamplingParams from dictionary") from e - self.hf_config = AutoConfig.from_pretrained( - tokenizer_name_or_path, - trust_remote_code=trust_remote_code - ) + self.hf_config = AutoConfig.from_pretrained(tokenizer_name_or_path, trust_remote_code=trust_remote_code) - self.tokenizer = get_tokenizer(tokenizer_name_or_path, - prompt_type=prompt_type, prompt_type_path=prompt_type_path) - self.pad_token_id = ( - self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None - else self.tokenizer.tokenizer.eos_token_id) + self.tokenizer = get_tokenizer(tokenizer_name_or_path, prompt_type=prompt_type, prompt_type_path=prompt_type_path) + self.pad_token_id = self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None else self.tokenizer.tokenizer.eos_token_id # Set up local rank using the helper function self.local_rank = get_local_rank() @@ -140,8 +130,8 @@ class VLLMInferEngine(BaseInferEngine): # Initialize parallel state if tensor parallel size is specified if train_tensor_parallel_size is not None: num_tp_per_train_tp = train_tensor_parallel_size // infer_tensor_parallel_size - os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0' - os.environ['MEGATRON_IMPORT_TIMERS'] = '0' + os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" + os.environ["MEGATRON_IMPORT_TIMERS"] = "0" initialize_parallel_state( infer_tensor_model_parallel_size=infer_tensor_parallel_size, train_tensor_model_parallel_size=train_tensor_parallel_size, @@ -169,7 +159,7 @@ class VLLMInferEngine(BaseInferEngine): skip_tokenizer_init=False, 
gpu_memory_utilization=gpu_memory_utilization, max_num_seqs=max_num_seqs, - max_model_len=max_model_len + max_model_len=max_model_len, ) self.model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() @@ -221,13 +211,9 @@ class VLLMInferEngine(BaseInferEngine): for name, params in self.model.named_parameters(): params.data = self.cpu_model[name] - def sync_model_weights(self, params, load_format='megatron'): - infer_parallel_config = InferParallelConfig(self.infer_tensor_parallel_size, self.infer_pipeline_parallel_size, - self.infer_expert_parallel_size) - load_megatron_weights(params, - self.model, - infer_parallel_config, - self.hf_config) + def sync_model_weights(self, params, load_format="megatron"): + infer_parallel_config = InferParallelConfig(self.infer_tensor_parallel_size, self.infer_pipeline_parallel_size, self.infer_expert_parallel_size) + load_megatron_weights(params, self.model, infer_parallel_config, self.hf_config) if hasattr(self.model.model.layers[0].self_attn, "mla_attn"): self._process_mla() @@ -238,17 +224,11 @@ class VLLMInferEngine(BaseInferEngine): mla.w_kc = None mla.w_vc = None - @torch.no_grad() @mstx_timer_decorator def generate_sequences(self, idx_list, **kwargs): self.init_cache_engine() with self.update_sampling_params(**kwargs): - response = self.llm.generate( - prompts=None, - sampling_params=self.sampling_params, - prompt_token_ids=idx_list, - use_tqdm=False - ) + response = self.llm.generate(prompts=None, sampling_params=self.sampling_params, prompt_token_ids=idx_list, use_tqdm=False) outs = self._post_process_outputs(response) self.free_cache_engine() return outs @@ -269,11 +249,9 @@ class VLLMInferEngine(BaseInferEngine): logprob.append(logprobs_dict[token_id].logprob) logprobs.append(torch.tensor(logprob)) - output_token_ids = pad_sequence(output_token_ids, batch_first=True, - padding_value=self.pad_token_id) + output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=self.pad_token_id) if len(logprobs) > 0: - logprobs = pad_sequence(logprobs, batch_first=True, - padding_value=self.pad_token_id) + logprobs = pad_sequence(logprobs, batch_first=True, padding_value=self.pad_token_id) return output_token_ids, logprobs @contextmanager @@ -292,10 +270,7 @@ class VLLMInferEngine(BaseInferEngine): setattr(self.sampling_params, key, value) def chat(self, conversation, sampling_params=None): - outputs = self.llm.chat( - conversation, - sampling_params=sampling_params if sampling_params else self.sampling_params, - use_tqdm=False) + outputs = self.llm.chat(conversation, sampling_params=sampling_params if sampling_params else self.sampling_params, use_tqdm=False) return outputs diff --git a/mindspeed_rl/trainer/base.py b/mindspeed_rl/trainer/base.py index e05e53a0..e5aaf8ed 100644 --- a/mindspeed_rl/trainer/base.py +++ b/mindspeed_rl/trainer/base.py @@ -26,8 +26,8 @@ class RayBaseTrainer(object): lam: float = 0.95, adv_estimator: str = "group_norm", missing_eos_penalty: float = 1.0, - kl_penalty: str = 'low_var_kl', - kl_ctrl_type: str = 'fixed', + kl_penalty: str = "low_var_kl", + kl_ctrl_type: str = "fixed", kl_horizon: int = 1000, kl_target: float = 100.0, init_kl_coef: float = 0.001, @@ -65,12 +65,17 @@ class RayBaseTrainer(object): self.num_cpus_for_local_task = num_cpus_for_local_task self.kwargs = kwargs + + self.sync_ref_model = kwargs.get("sync_ref_model", False) + self.ref_model_sync_steps = kwargs.get("ref_model_sync_steps") + self.ref_model_mixup_alpha = kwargs.get("ref_model_mixup_alpha") + # define 
KL control - if kl_ctrl_type == 'fixed': + if kl_ctrl_type == "fixed": self.kl_ctrl = FixedKLController(init_kl_coef=self.init_kl_coef) - elif kl_ctrl_type == 'adaptive': + elif kl_ctrl_type == "adaptive": if self.kl_horizon <= 0: - raise ValueError(f'horizon must be larger than 0. Got {self.kl_horizon}') + raise ValueError(f"horizon must be larger than 0. Got {self.kl_horizon}") self.kl_ctrl = AdaptiveKLController(init_kl_coef=init_kl_coef, target_kl=kl_target, horizon=kl_horizon) @@ -79,6 +84,7 @@ class RayBaseTrainer(object): self.wandb = None self.tensorboard = None + if kwargs.get("use_wandb", ""): self.wandb = WandbLogger(kwargs) if kwargs.get("use_tensorboard", "") and self.wandb is None: diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index a38d0bb9..d980930b 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -95,7 +95,7 @@ class RayGRPOTrainer(RayBaseTrainer): self.set_actor_log_prob_skip_flag() def transfer_dock_init(self): - self.transfer_dock = GRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt, + self.transfer_dock: GRPOTransferDock = GRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt, self.metrics, addition_columns=self.dataset_additional_keys) self.actor_worker.sync_init_transfer_dock(self.transfer_dock) self.ref_worker.sync_init_transfer_dock(self.transfer_dock) @@ -156,7 +156,7 @@ class RayGRPOTrainer(RayBaseTrainer): self.actor_worker.compute_log_prob(blocking=self.blocking) self.actor_worker.wait_all_ref_objs_run_over() - + self.ref_worker.wait_all_ref_objs_run_over() for reward in self.reward_list: if hasattr(reward, 'wait_all_ref_objs_run_over'): @@ -188,8 +188,13 @@ class RayGRPOTrainer(RayBaseTrainer): self.tensorboard.add_scalar(f"train/{k}", v, iteration) if self.wandb is not None: self.wandb.log_metrics(metrics.metric, iteration) + #* Sync reference weight + if self.sync_ref_model and (iteration + 1) % self.ref_model_sync_steps == 0: + self.update_ref(iteration) if iteration % self.save_interval == 0: self.save_checkpoint(iteration) + if self.actor_worker.rl_config.ref_save is not None: + self.save_ref_checkpoint(iteration) logger.info('after grpo training is done') ray.shutdown() @@ -213,11 +218,11 @@ class RayGRPOTrainer(RayBaseTrainer): end_adv_time = time.time() ray.get( self.transfer_dock.update_metrics.remote( - "timing/adv", + "timing/adv", value=[round(end_adv_time, 4), round(start_adv_time, 4)], cumulate=True ) - ) + ) ray.get( self.transfer_dock.update_metrics.remote( "end_time/end_adv_time", @@ -228,3 +233,51 @@ class RayGRPOTrainer(RayBaseTrainer): def save_checkpoint(self, iteration: int): self.actor_worker.save_checkpoint(iteration) + + def save_ref_checkpoint(self, iteration: int): + self.ref_worker.save_ref_checkpoint(iteration) + + def update_ref(self, iteration=-1): + cur_iteration = iteration + + logger = Loggers("update_ref") + + logger.info(f"Start mixup ref and actor model with ratio={self.ref_model_mixup_alpha}") + + #* Step 1 clean worker + self.actor_worker.wait_all_ref_objs_run_over() + self.ref_worker.wait_all_ref_objs_run_over() + + #* Step 2 prepare weight + ref_sync_objs = [] + for ref in self.ref_worker.actor_handlers: + ref_sync_objs.append(ref.pre_ref_update.remote()) + list_data_parallel_global_ranks = ray.get(ref_sync_objs) + list_dp0_rank = [ranks[0] for ranks in list_data_parallel_global_ranks] + + #* Step 3 offer download command + with Timer(name="sync ref", 
logger=None) as timer: + logger.info(f"Using DP rank list: {list_dp0_rank}") + list_main_actor_handlers = [self.actor_worker.actor_handlers[i] for i in list_dp0_rank] + list_main_ref_handlers = [self.ref_worker.actor_handlers[i] for i in list_dp0_rank] + + #* transmit + actor_sync_objs = [] + for actor, ref in zip(list_main_actor_handlers, list_main_ref_handlers): + actor_sync_objs.append(actor.sync_weight_to_ref.remote(ref)) + ray.get(actor_sync_objs) + + logger.info(f"Step {cur_iteration}, sync ref model with actor model, " + f"alpha: {self.ref_model_mixup_alpha}, time: {timer.last:.2f}s.") + + #* Step 4 Finish weight sync + logger.info(f"post ref update") + ref_sync_objs = [] + for ref in self.ref_worker.actor_handlers: + ref_sync_objs.append(ref.post_ref_update.remote()) + ray.get(ref_sync_objs) + + #* Step 5 Clean worker + self.actor_worker.wait_all_ref_objs_run_over() + self.ref_worker.wait_all_ref_objs_run_over() + logger.info("Finished sync ref.") \ No newline at end of file diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 4cdbb73b..6b08c6be 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -1,11 +1,13 @@ # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +from multiprocessing.managers import ListProxy import time import dataclasses import copy from typing import Callable import ray +import ray.actor from torch import nn import torch from transformers import AutoConfig @@ -229,7 +231,7 @@ class ActorHybridWorkerBase(BaseWorker): actor_generate_profiler = profiler_start(self.profiler_config, role="actor_generate", profiler_iteration=self.prof_iteration) MsProbe.debugger_start(self.inference_model.model, tag='actor_generate_sequences') - + start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True) > 0: batch_data, index = self.dispatch_transfer_dock_data( @@ -350,6 +352,30 @@ class ActorHybridWorkerBase(BaseWorker): profiler_step(actor_compute_log_prob_profiler) MsProbe.debugger_stop('actor_compute_log_prob') + @mstx_timer_decorator + def sync_weight_to_ref(self, peer_handler: ray.actor.ActorHandle): + from mindspeed_rl.utils.loggers import Loggers + import logging + import math + + logger = Loggers(f"sync_weight_to_ref_rank", logger_level=logging.DEBUG) + + self.empty_cache() + self.sharding_manager.megatron_offloader.onload_param() + + list_param = list(self.actor_hybrid.train_actor.model[0].named_parameters()) + num_param = len(list_param) + + peer_rank = ray.get(peer_handler.get_value.remote("_local_rank")) + for i, (key, val) in enumerate(list_param): + size = math.prod(val.shape) * val.dtype.itemsize / 1024 / 1024 + logger.info(f"NPU: {self._local_rank} --> {peer_rank}, ({i:-4d} / {num_param:-4d}) {size:.2f} MB", key) + + ray.get(peer_handler.receive_from_actor.remote(key, val.cpu())) + + logger.info(f"ActorHybridWorker (rank: {self._local_rank}) sync done") + self.sharding_manager.megatron_offloader.offload_param() + def _build_model_optimizer(self): actor_module, optimizer, opt_param_scheduler = self.setup_model_and_optimizer( self.model_provider, self.model_type.encoder_or_decoder) diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py index 51d91a36..2b8ccf31 100644 --- a/mindspeed_rl/workers/base_worker.py +++ b/mindspeed_rl/workers/base_worker.py @@ -118,6 +118,9 @@ class BaseRayWorker: def get_master_addr_port(self): return self._master_addr, self._master_port 
+ def get_value(self, key): + return getattr(self, key) + class BaseWorker(BaseRayWorker, ABC): """基类,封装通用逻辑但保留子类接口和装饰器""" @@ -203,6 +206,10 @@ class BaseWorker(BaseRayWorker, ABC): def init_transfer_dock(self, td): raise NotImplementedError("This method should be implemented by subclasses") + @property + def data_parallel_world_size(self): + return self.parallel_state.get_data_parallel_world_size() + @property def td(self): """ diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 942582bb..7c2b9bb5 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -9,11 +9,13 @@ import torch from mindspeed_rl.config_cls.megatron_config import MegatronConfig from mindspeed_rl.config_cls.rl_config import RLConfig -from mindspeed_rl.config_cls.generate_config import GenerateConfig +from mindspeed_rl.config_cls.generate_config import GenerateConfig from mindspeed_rl.models.reference import Reference +from mindspeed_rl.trainer.utils import parallel_state from mindspeed_rl.utils.pad_process import truncate_rows from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.workers.base_worker import BaseWorker +from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronOffLoader from mindspeed_rl.utils.compute import get_parallel_state from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank from mindspeed_rl.utils.utils import mstx_timer_decorator @@ -35,25 +37,18 @@ class ReferenceWorkerBase(BaseWorker): """ def __init__( - self, - megatron_config: MegatronConfig, - rl_config: RLConfig, - generate_config: GenerateConfig, - model_provider: Callable, - initialize_func: Callable, - tokenizer: BaseTokenizer = None, - get_megatron_module: Callable = None, - **kwargs + self, + megatron_config: MegatronConfig, + rl_config: RLConfig, + generate_config: GenerateConfig, + model_provider: Callable, + initialize_func: Callable, + tokenizer: BaseTokenizer = None, + get_megatron_module: Callable = None, + **kwargs ): super().__init__( - megatron_config, - rl_config, - generate_config, - model_provider=model_provider, - initialize_func=initialize_func, - tokenizer=tokenizer, - get_megatron_module=get_megatron_module, - **kwargs + megatron_config, rl_config, generate_config, model_provider=model_provider, initialize_func=initialize_func, tokenizer=tokenizer, get_megatron_module=get_megatron_module, **kwargs ) self.reference = None @@ -62,12 +57,13 @@ class ReferenceWorkerBase(BaseWorker): self.model = self.get_model(self.model_provider, self.model_type, wrap_with_ddp=False) if self.megatron_config.load is not None or self.megatron_config.pretrained_checkpoint is not None: - self.megatron_config.iteration, self.megatron_config.num_floating_point_operations_so_far = self.load_checkpoint( - self.model, None, None) + self.megatron_config.iteration, self.megatron_config.num_floating_point_operations_so_far = self.load_checkpoint(self.model, None, None) else: self.megatron_config.iteration = 0 self.megatron_config.num_floating_point_operations_so_far = 0 + self.megatron_offloader = MegatronOffLoader(self.model, wrap_with_ddp=False) + self.reference = Reference( self.model, beta=self.rl_config.beta, @@ -78,7 +74,7 @@ class ReferenceWorkerBase(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, - 
temperature=self.generate_config.sampling_config["temperature"] + temperature=self.generate_config.sampling_config["temperature"], ) def init_transfer_dock(self, td): @@ -86,32 +82,26 @@ class ReferenceWorkerBase(BaseWorker): @mstx_timer_decorator def compute_ref_log_prob(self): - experience_consumer_stage = 'ref_log_prob' - experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length'] + experience_consumer_stage = "ref_log_prob" + experience_columns = ["input_ids", "responses", "response_length", "prompt_length"] experience_count = self.rl_config.ref_dispatch_size - sorted_indexes = self.get_dp_range_indexes(experience_count, - use_vllm=False) if self.rl_config.guarantee_order else None + sorted_indexes = self.get_dp_range_indexes(experience_count, use_vllm=False) if self.rl_config.guarantee_order else None start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0: - batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, - experience_columns, - experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - indexes=sorted_indexes.pop( - 0) if self.rl_config.guarantee_order else None, - get_n_samples=False) - + batch_data, index = self.dispatch_transfer_dock_data( + experience_consumer_stage, + experience_columns, + experience_count, + tp_size=self.megatron_config.tensor_model_parallel_size, + indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None, + get_n_samples=False, + ) + if not start_time_defined: start_time = time.time() start_time_defined = True - ray.get( - self.td.update_metrics.remote( - "start_time/reference_model", - value=[round(start_time, 4)], - cumulate=True - ) - ) + ray.get(self.td.update_metrics.remote("start_time/reference_model", value=[round(start_time, 4)], cumulate=True)) if batch_data and index: output, batch = self.reference.compute_log_prob(batch_data) @@ -119,31 +109,53 @@ class ReferenceWorkerBase(BaseWorker): # only on last rank. 
It should be on every tp rank log_probs = torch.cat(output, dim=0) # (bs, seq_size) log_probs = log_probs.to(torch.float32) - log_probs = truncate_rows(log_probs, batch['response_length']) - output = {'ref_log_prob': log_probs} + log_probs = truncate_rows(log_probs, batch["response_length"]) + output = {"ref_log_prob": log_probs} self.collect_transfer_dock_data(output, index) end_time = time.time() - ray.get( - self.td.update_metrics.remote( - "timing/reference_model", - value=[round(end_time, 4), round(start_time, 4)], - cumulate=True - ) - ) + ray.get(self.td.update_metrics.remote("timing/reference_model", value=[round(end_time, 4), round(start_time, 4)], cumulate=True)) parallel_state = get_parallel_state() use_vllm = False if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0: ref_end_time = time.time() - ray.get( - self.td.update_metrics.remote( - "end_time/reference", - value=[round(ref_end_time, 4)] - ) - ) + ray.get(self.td.update_metrics.remote("end_time/reference", value=[round(ref_end_time, 4)])) logger.info("finish compute ref log prob") self.empty_cache() + def save_ref_ckpt(self, iteration: int): + from megatron.training import get_args + + args = get_args() + + original_save_path = args.save + args.save = self.rl_config.ref_save + + self.save_checkpoint(iteration, self.model, None, None, None) + + args.save = original_save_path + + def pre_ref_update(self): + self.megatron_offloader.offload_param() + + self.ref_params = dict(self.reference.model[0].named_parameters()) + + return get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS + + def receive_from_actor(self, layer_name: str, tensor_weight: torch.Tensor): + ref_key = layer_name.replace("module.module.", "module.") + + alpha = self.rl_config.ref_model_mixup_alpha + self.ref_params[ref_key].mul_(1 - alpha).add_(tensor_weight, alpha=alpha) + + def post_ref_update(self): + self.megatron_offloader.onload_param() + + if len(get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS) > 1: + rank_dp0 = get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS[0] + for _, v in self.ref_params.items(): + torch.distributed.broadcast(v, src=rank_dp0, group=get_parallel_state()._DATA_PARALLEL_GROUP) + @ray.remote(resources={"NPU": 0.3}) class ReferenceWorker(ReferenceWorkerBase): diff --git a/mindspeed_rl/workers/scheduler/launcher.py b/mindspeed_rl/workers/scheduler/launcher.py index c95c60a0..9ebb8425 100644 --- a/mindspeed_rl/workers/scheduler/launcher.py +++ b/mindspeed_rl/workers/scheduler/launcher.py @@ -297,6 +297,12 @@ class RayActorGroup: actor_train_objs.append(actor.save_ckpt.remote(iteration)) return ray.get(actor_train_objs) + def save_ref_checkpoint(self, iteration): + actor_train_objs = [] + for actor in self.actor_handlers: + actor_train_objs.append(actor.save_ref_ckpt.remote(iteration)) + return ray.get(actor_train_objs) + def initialize(self): for actor in self.actor_handlers: self.temp_actor_ref_objs.append(actor.initialize.remote()) diff --git a/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml b/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml new file mode 100644 index 00000000..5d6801be --- /dev/null +++ b/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml @@ -0,0 +1,184 @@ +defaults: + - model: + - qwen25_7b + +megatron_training: + model: qwen25_7b + use_fused_rmsnorm: true + use_mcore_models: true + sequence_parallel: true + use_flash_attn: true + no_masked_softmax_fusion: true + attention_softmax_in_fp32: true + 
no_gradient_accumulation_fusion: true + use_fused_swiglu: true + use_fused_rotary_pos_emb: true + bf16: true + use_distributed_optimizer: true + tokenizer_type: PretrainedFromHF + tokenizer_name_or_path: /workspace/weight/Qwen2.5-7B-Instruct + global_batch_size: 4 + seq_length: 1024 + save_interval: 50 + train_iters: 1 + stage: ray_grpo + attention_dropout: 0.0 + init_method_std: 0.01 + hidden_dropout: 0.0 + distributed_backend: nccl + no_shared_storage: true + variable_seq_lengths: true + dataset_additional_keys: ['labels',] + data_path: /workspace/dataset/pe_nlp/pe_nlp + split: 100,0,0 + +actor_config: + model: qwen25_7b + micro_batch_size: 4 + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 1 + lr: 5e-7 + lr_decay_style: cosine + min_lr: 5e-8 + weight_decay: 0.0 + lr_warmup_fraction: 0.0 + clip_grad: 1 + adam_beta1: 0.9 + adam_beta2: 0.95 + initial_loss_scale: 4096 + finetune: true + load: /workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 + save: ./ckpt + no_load_optim: true + no_load_rng: true + +ref_config: + model: qwen25_7b + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 2 + micro_batch_size: 8 + load: /workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 + no_load_optim: true + no_load_rng: true + +reward_config: + model: qwen25_7b + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 2 + micro_batch_size: 8 + load: /workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 + no_load_optim: true + no_load_rng: true + +rl_config: + use_integrated_worker: false + guarantee_order: false + blocking: true + gamma: 1.0 + lam: 0.95 + adv_estimator: group_norm + kl_penalty: low_var_kl + kl_ctrl_type: fixed + init_kl_coef: 0.01 + mini_batch_size: 64 + max_prompt_length: 1024 + epochs: 1 + clip_ratio: 0.2 + entropy_coeff: 0.001 + shuffle_minibatch: false + n_samples_per_prompt: 8 + rule_reward: true + verifier_function: ["base_acc"] + verifier_weight: [1.0] + verifier_parallel: 4 + verifier_timeout: 120 + colocate_actor_ref: false + colocate_all_models: false + use_tensorboard: true + + sync_ref_model: true + ref_model_sync_steps: 25 + ref_model_mixup_alpha: 0.9 + ref_save: /workspace/weight/ref_save + + + + actor_resource: + num_npus: 4 + reference_resource: + num_npus: 4 + +generate_config: + + trust_remote_code: true + offload_train_optimizer: true + offload_train_grad: true + offload_train_param: true + + # Parallel configuration for inference + infer_tensor_parallel_size: 2 + infer_pipeline_parallel_size: 1 + infer_expert_parallel_size: 1 + + # vLLM model settings + max_num_seqs: 512 + max_model_len: 4096 + dtype: "bfloat16" + gpu_memory_utilization: 0.9 + num_scheduler_steps: 1 + + # Sampling configuration + sampling_config: + logprobs: 1 + max_tokens: 2048 + top_p: 1 + top_k: -1 + min_p: 0.0 + seed: 1234 + temperature: 1.0 + detokenize: false + +profiler_config: + actor: + profile: false + mstx: false + stage: actor_update + profile_save_path: ./profiler_data + profile_export_type: db + profile_step_start: 1 + profile_step_end: 2 + profile_level: level1 + profile_with_memory: true + profile_record_shapes: false + profile_with_cpu: true + profile_with_npu: true + profile_with_module: false + profile_analysis: true + profile_ranks: [0, 1] + + reference: + profile: false + mstx: false + stage: ref + profile_save_path: ./profiler_data + profile_export_type: db + profile_step_start: 1 + profile_step_end: 2 + profile_level: level1 + profile_with_memory: true + profile_record_shapes: false + profile_with_cpu: true + profile_with_npu: true + profile_with_module: false + profile_analysis: true + profile_ranks: [0, 1] +
+msprobe_config: + msprobe: false + dump_path: "./msprobe_dump" + key_data_dump: true + configurations_dump: true + actor_train_dump: true + reference_dump: true + step_start: 0 + step_end: 0 \ No newline at end of file diff --git a/tests/st/ref_update/test_full_grpo_qwen25_7b_ref_update.sh b/tests/st/ref_update/test_full_grpo_qwen25_7b_ref_update.sh new file mode 100644 index 00000000..f366a755 --- /dev/null +++ b/tests/st/ref_update/test_full_grpo_qwen25_7b_ref_update.sh @@ -0,0 +1,34 @@ +# pkill -9 python +ray stop --force +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export HCCL_DETERMINISTIC=True + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +export PYTHONPATH=$SCRIPT_DIR/../:$PYTHONPATH +PROJECT_PATH=$SCRIPT_DIR/../../.. +PROFILER_DATA_PATH=$PROJECT_PATH/ci/profiler_data +rm -rf $PROFILER_DATA_PATH + +mkdir -p $PROJECT_PATH/logs + +export GLOO_SOCKET_IFNAME=enp189s0f0 +export RAY_DEBUG_POST_MORTEM=1 + +python $PROJECT_PATH/cli/train_grpo.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=test_grpo_trainer_qwen25_7b_ref_split \ + actor_config.tensor_model_parallel_size=2 \ + actor_config.pipeline_model_parallel_size=2 \ + actor_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 \ + ref_config.tensor_model_parallel_size=2 \ + ref_config.pipeline_model_parallel_size=2 \ + ref_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 \ + megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-7B-Instruct \ + rl_config.actor_resource.num_npus=4 \ + rl_config.reference_resource.num_npus=4 \ + megatron_training.train_iters=500 \ + rl_config.ref_save=/workspace/weight/ref_save \ + actor_config.save=/workspace/weight/actor_save \ + megatron_training.save_interval=100 \ + rl_config.ref_model_sync_steps=100 \ + 2>&1 | tee $PROJECT_PATH/logs/test_update_ref_full.log diff --git a/tests/st/ref_update/test_module_entry_ref_update.sh b/tests/st/ref_update/test_module_entry_ref_update.sh new file mode 100644 index 00000000..ae8251d0 --- /dev/null +++ b/tests/st/ref_update/test_module_entry_ref_update.sh @@ -0,0 +1,70 @@ +#!/bin/bash +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export HCCL_DETERMINISTIC=True + + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +export PYTHONPATH=$SCRIPT_DIR/../:$PYTHONPATH +PROJECT_PATH=$SCRIPT_DIR/../../.. 
+PROFILER_DATA_PATH=$PROJECT_PATH/ci/profiler_data +rm -rf $PROFILER_DATA_PATH + +mkdir -p $PROJECT_PATH/logs + +export GLOO_SOCKET_IFNAME=enp189s0f0 +export RAY_DEBUG_POST_MORTEM=1 + +case1=on +case2=on +case3=on + + +if [[ $case1 == on ]];then +#* Actor dp2, Ref dp2, 8卡同步 + +python $SCRIPT_DIR/test_ref_update.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=test_grpo_trainer_qwen25_7b_ref_split \ + actor_config.tensor_model_parallel_size=2 \ + actor_config.pipeline_model_parallel_size=1 \ + actor_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp1 \ + ref_config.tensor_model_parallel_size=2 \ + ref_config.pipeline_model_parallel_size=1 \ + ref_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp1 \ + megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-7B-Instruct \ + rl_config.actor_resource.num_npus=4 \ + rl_config.reference_resource.num_npus=4 \ + 2>&1 | tee $PROJECT_PATH/logs/test_ref_update_case1.log +fi + + +if [[ $case2 == on ]];then +#* Actor dp1, Ref dp1, 8卡同步 + +python $SCRIPT_DIR/test_ref_update.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=test_grpo_trainer_qwen25_7b_ref_split \ + actor_config.tensor_model_parallel_size=2 \ + actor_config.pipeline_model_parallel_size=2 \ + actor_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 \ + ref_config.tensor_model_parallel_size=2 \ + ref_config.pipeline_model_parallel_size=2 \ + ref_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp2 \ + megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-7B-Instruct \ + rl_config.actor_resource.num_npus=4 \ + rl_config.reference_resource.num_npus=4 \ + 2>&1 | tee $PROJECT_PATH/logs/test_ref_update_case2.log +fi + + +if [[ $case3 == on ]];then +#* Actor dp1, Ref dp1, 4卡同步 + +python $SCRIPT_DIR/test_ref_update.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=test_grpo_trainer_qwen25_7b_ref_split \ + actor_config.tensor_model_parallel_size=2 \ + actor_config.pipeline_model_parallel_size=1 \ + actor_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp1 \ + ref_config.tensor_model_parallel_size=2 \ + ref_config.pipeline_model_parallel_size=1 \ + ref_config.load=/workspace/weight/Qwen2.5-7B-Instruct-tp2pp1 \ + megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-7B-Instruct \ + rl_config.actor_resource.num_npus=2 \ + rl_config.reference_resource.num_npus=2 \ + 2>&1 | tee $PROJECT_PATH/logs/test_ref_update_case3.log +fi \ No newline at end of file diff --git a/tests/st/ref_update/test_ref_update.py b/tests/st/ref_update/test_ref_update.py new file mode 100644 index 00000000..0b7e332c --- /dev/null +++ b/tests/st/ref_update/test_ref_update.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd.2023-2025. All rights reserved. +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reversed. 
+ +import sys +import logging +from datetime import timedelta +from tokenize import group +from unittest import TestCase +from unittest.mock import MagicMock +from venv import logger +import ray +import hydra +import torch +import yaml +from pathlib import Path + +from mindspeed_rl.config_cls.validate_config import validate_rl_args +from mindspeed_rl.models.rollout.vllm_engine import VLLMInferEngine + +from mindspeed_rl.utils.loggers import Loggers +from mindspeed_rl.config_cls.megatron_config import MegatronConfig +from mindspeed_rl.config_cls.rl_config import RLConfig +from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig +from mindspeed_rl.config_cls.generate_config import GenerateConfig +from mindspeed_rl.trainer.grpo_trainer_hybrid import RayGRPOTrainer +from mindspeed_rl.workers.scheduler.launcher import RayActorGroup +from mindspeed_rl.workers.actor_hybrid_worker import ActorHybridWorkerBase +from mindspeed_rl.workers.reference_woker import ReferenceWorkerBase +from mindspeed_rl.utils.compute import get_parallel_state + +curr_file_dir = Path(__file__).absolute().parent.parent.parent.parent +sys.path.append(curr_file_dir) + +from cli.train_grpo import get_megatron_module, gpt_model_provider, initialize_megatron + +logger = Loggers("test_ref_updte", logger_level=logging.DEBUG) + +CHECK_M = 3 + + +def make_megatron_config(args): + + actor_config = MegatronConfig( + { + **args.get("megatron_training"), + **args.get("actor_config"), + }, + args.get("model"), + ) + + rl_config = RLConfig(args.get("rl_config")) + + ref_config = MegatronConfig( + { + **args.get("megatron_training"), + **args.get("ref_config"), + }, + args.get("model"), + ) + reward_config = MegatronConfig( + { + **args.get("megatron_training"), + }, + args.get("model"), + ) + + generate_config = GenerateConfig(args.get("generate_config")) + + validate_rl_args(actor_config, ref_config, reward_config, rl_config, generate_config) + + profiler_config = {} + + profiler_config.update( + { + "common": ProfilerConfig(args.get("profiler_config", {}).get(f"common", {}), role="common"), + } + ) + + msprobe_config = MsprobeConfig(args.get("msprobe_config", {}), role="common") + + return { + "actor_config": actor_config, + "ref_config": ref_config, + "reward_config": reward_config, + "rl_config": rl_config, + "generate_config": generate_config, + "profiler_config": profiler_config, + "msprobe_config": msprobe_config, + } + + +path_test_tmp = curr_file_dir.joinpath("tmp/test_ref_update") +path_test_tmp.mkdir(exist_ok=True, parents=True) +path_test_tmp_save = path_test_tmp.joinpath("saved_layers") +path_test_tmp_save.mkdir(exist_ok=True, parents=True) +path_test_tmp_save_ref = path_test_tmp_save.joinpath("ref") +path_test_tmp_save_actor = path_test_tmp_save.joinpath("actor") +path_test_tmp_save_ref.mkdir(exist_ok=True, parents=True) +path_test_tmp_save_actor.mkdir(exist_ok=True, parents=True) + + +def pre_ref_update(self): + + print(f"[pre_ref_update()] rank: {self.rank}, list_dp: {get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS}") + self.megatron_offloader.offload_param() + self.ref_params = dict(self.reference.model[0].named_parameters()) + + #! TEST: Modify the first n weights and perform verification. 
+ count = 0 + for k, v in self.ref_params.items(): + v.add_(-v) + v.add_(torch.ones_like(v, device=v.device)) + torch.save(v.cpu(), path_test_tmp_save_ref.joinpath(f"pre_rank_{self._local_rank}_" + k + ".pt")) + count += 1 + if count == CHECK_M: + break + return get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS + + +def post_ref_update(self): + from mindspeed_rl.utils.compute import get_parallel_state + + print("post_ref_update") + self.megatron_offloader.onload_param() + + if len(get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS) > 1: + rank_dp0 = get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS[0] + print(f"Start DP broadcast sync.") + for _, v in self.ref_params.items(): + torch.distributed.broadcast(v, src=rank_dp0, group=get_parallel_state()._DATA_PARALLEL_GROUP) + print(f"Finished DP broadcast sync.") + + else: + print(f"DP=1, ignore ref dp broadcast.") + + #! TEST: Modify the first n weights and perform verification. + count = 0 + for k, v in self.ref_params.items(): + torch.save(v.cpu(), path_test_tmp_save_ref.joinpath(f"post_rank_{self._local_rank}_" + k + ".pt")) + count += 1 + if count == CHECK_M: + break + #! + + +def sync_weight_to_ref(self, peer_handler: ray.actor.ActorHandle): + import pandas as pd + import math + + self.empty_cache() + + self.sharding_manager.megatron_offloader.onload_param() + + list_param = list(self.actor_hybrid.train_actor.model[0].named_parameters()) + + num_param = len(list_param) + dict_param_info = [ + { + "param_name": k, + "size": math.prod(v.shape) * v.dtype.itemsize // 1024, + "dype": str(v.dtype), + "shape": str(list(v.shape)), + } + for k, v in list_param + ] + df = pd.DataFrame(dict_param_info) + + print(f"Num of weight to sync: {len(df)}") + print(f"Size of weight to sync: {sum(df['size'])/1024} MB.") + + #! TEST: Modify the first n weights and perform verification. + count = 0 + for k, v in list_param: + v.add_(-v) + torch.save(v.cpu(), path_test_tmp_save_actor.joinpath(f"rank_{self._local_rank}_" + k + ".pt")) + count += 1 + if count == CHECK_M: + break + + peer_rank = ray.get(peer_handler.get_value.remote("_local_rank")) + for i, (key, val) in enumerate(list_param): + size = math.prod(val.shape) * val.dtype.itemsize / 1024 / 1024 + logger.debug(f"NPU: {self._local_rank} --> {peer_rank}, ({i:-4d} / {num_param:-4d}) {size:.2f} MB, {key}") + + ray.get(peer_handler.receive_from_actor.remote(key, val.cpu())) + + logger.info(f"ActorHybridWorker (rank: {self._local_rank}) sync done") + self.sharding_manager.megatron_offloader.offload_param() + + +ReferenceWorkerBase.post_ref_update = post_ref_update +ReferenceWorkerBase.pre_ref_update = pre_ref_update + +ActorHybridWorkerBase.sync_weight_to_ref = sync_weight_to_ref + + +@ray.remote(resources={"NPU": 0.7}) +class ReferenceWorker(ReferenceWorkerBase): + pass + + +@ray.remote(resources={"NPU": 0.7}) +class ActorHybridWorker(ActorHybridWorkerBase): + pass + + +class TestRefUpdate: + """ + Test reference model update with ref-actor spliting mode. 
+ """ + + def __init__(self, args) -> None: + actor_config, ref_config, _, rl_config, generate_config, profiler_config, msprobe_config = make_megatron_config(args).values() + tokenizer = MagicMock() + + self.actor_worker = RayActorGroup( + worker=ActorHybridWorker, + placement_group=None, + megatron_config=actor_config, + rl_config=rl_config, + generate_config=generate_config, + model_provider=gpt_model_provider, + profiler_config=profiler_config["common"], + msprobe_config=msprobe_config, + tokenizer=tokenizer, + initialize_func=initialize_megatron, + get_megatron_module=get_megatron_module, + global_batch_size=actor_config.global_batch_size * rl_config.n_samples_per_prompt, + ).initialize() + + self.reference_worker = RayActorGroup( + worker=ReferenceWorker, + placement_group=None, + megatron_config=ref_config, + rl_config=rl_config, + generate_config=generate_config, + model_provider=gpt_model_provider, + profiler_config=profiler_config["common"], + msprobe_config=msprobe_config, + tokenizer=tokenizer, + initialize_func=initialize_megatron, + get_megatron_module=get_megatron_module, + global_batch_size=actor_config.global_batch_size * rl_config.n_samples_per_prompt, + ).initialize() + + self.actor_worker.wait_all_ref_objs_run_over() + self.reference_worker.wait_all_ref_objs_run_over() + + self.trainer = RayGRPOTrainer( + self.actor_worker, + self.reference_worker, + reward_list=[], + tokenizer=tokenizer, + global_batch_size=actor_config.global_batch_size, + micro_batch_size=rl_config.adv_dispatch_size, + train_iters=actor_config.train_iters, + save_interval=actor_config.save_interval, + dataset_additional_keys=actor_config.dataset_additional_keys, + **rl_config.dict(), + ) + + def reference_update(self): + self.trainer.update_ref() + print(f"reference_update done.") + + +@hydra.main( + config_path="../configs/", + config_name="test_grpo_trainer_qwen25_7b_ref_update", + version_base=None, +) +def main(config): + if not ray.is_initialized(): + # this is for local ray luster + logger.info("start initializing local ray cluster") + rl_config = RLConfig(config.get("rl_config")) + with open(Path(curr_file_dir).joinpath(rl_config.runtime_env_path)) as file: + runtime_env = yaml.safe_load(file) + logger.info(f"ray init with runtime_env: {runtime_env}") + ray.init(runtime_env=runtime_env) + + test_case = TestRefUpdate(config) + test_case.reference_update() + + print("Reference sync test done.") + + +if __name__ == "__main__": + main() -- Gitee From cb079a368750c69060951608c4b2ae3ae4f1c199 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Fri, 6 Jun 2025 14:45:02 +0800 Subject: [PATCH 02/12] . 
--- mindspeed_rl/models/rollout/vllm_engine.py | 117 +++++++++++++-------- mindspeed_rl/workers/reference_woker.py | 95 +++++++++++------ 2 files changed, 133 insertions(+), 79 deletions(-) diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 492996db..95da450b 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -32,7 +32,11 @@ from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.models.base.base_inference_engine import BaseInferEngine from mindspeed_rl.config_cls.megatron_config import MegatronConfig from mindspeed_rl.models.rollout.vllm_adapter.vllm_parallel_state import initialize_parallel_state -from mindspeed_rl.models.rollout.vllm_adapter.megatron_weight_loaders import load_megatron_weights, update_megatron_weight_loader, InferParallelConfig +from mindspeed_rl.models.rollout.vllm_adapter.megatron_weight_loaders import ( + load_megatron_weights, + update_megatron_weight_loader, + InferParallelConfig +) from mindspeed_rl.utils import get_tokenizer logger = Loggers("vllm_engine") @@ -40,27 +44,27 @@ logger = Loggers("vllm_engine") class VLLMInferEngine(BaseInferEngine): def __init__( - self, - tokenizer_name_or_path: str, - train_tensor_parallel_size: int, - train_pipeline_parallel_size: int, - train_expert_parallel_size: int, - infer_tensor_parallel_size: int, - infer_pipeline_parallel_size: int, - infer_expert_parallel_size: int, - megatron_config: MegatronConfig, - sampling_config: dict, - prompt_type: str = None, - prompt_type_path: str = None, - enable_prefix_caching: bool = False, - num_scheduler_steps: int = 1, - max_num_seqs: int = 1, - max_model_len: int = 2048, - dtype: str = "bfloat16", - gpu_memory_utilization: float = 0.5, - trust_remote_code: bool = True, - load_format: str = "megatron", - **kwargs, + self, + tokenizer_name_or_path: str, + train_tensor_parallel_size: int, + train_pipeline_parallel_size: int, + train_expert_parallel_size: int, + infer_tensor_parallel_size: int, + infer_pipeline_parallel_size: int, + infer_expert_parallel_size: int, + megatron_config: MegatronConfig, + sampling_config: dict, + prompt_type: str = None, + prompt_type_path: str = None, + enable_prefix_caching: bool = False, + num_scheduler_steps: int = 1, + max_num_seqs: int = 1, + max_model_len: int = 2048, + dtype: str = "bfloat16", + gpu_memory_utilization: float = 0.5, + trust_remote_code: bool = True, + load_format: str = "megatron", + **kwargs ): """ Initialize the VLLM inference engine. 
@@ -98,31 +102,37 @@ class VLLMInferEngine(BaseInferEngine): max_model_len=max_model_len, dtype=dtype, gpu_memory_utilization=gpu_memory_utilization, - trust_remote_code=trust_remote_code, + trust_remote_code=trust_remote_code ) # Additional initialization logic for VLLMInferEngine # Initialize sampling parameters from SamplingConfig self.sampling_config = sampling_config try: self.sampling_params = SamplingParams( - n=sampling_config.get("num_completions", 1), - logprobs=sampling_config.get("logprobs", 1), - max_tokens=sampling_config.get("max_tokens", 128), - best_of=sampling_config.get("best_of", 2), - top_p=sampling_config.get("top_p", 1.0), - top_k=sampling_config.get("top_k", 50), - min_p=sampling_config.get("min_p", 0.0), - temperature=sampling_config.get("temperature", 0.2), - detokenize=sampling_config.get("detokenize", False), - seed=sampling_config.get("seed", None), + n=sampling_config.get('num_completions', 1), + logprobs=sampling_config.get('logprobs', 1), + max_tokens=sampling_config.get('max_tokens', 128), + best_of=sampling_config.get('best_of', 2), + top_p=sampling_config.get('top_p', 1.0), + top_k=sampling_config.get('top_k', 50), + min_p=sampling_config.get('min_p', 0.0), + temperature=sampling_config.get('temperature', 0.2), + detokenize=sampling_config.get('detokenize', False), + seed=sampling_config.get('seed', None) ) except Exception as e: raise ValueError(f"Error creating SamplingParams from dictionary") from e - self.hf_config = AutoConfig.from_pretrained(tokenizer_name_or_path, trust_remote_code=trust_remote_code) + self.hf_config = AutoConfig.from_pretrained( + tokenizer_name_or_path, + trust_remote_code=trust_remote_code + ) - self.tokenizer = get_tokenizer(tokenizer_name_or_path, prompt_type=prompt_type, prompt_type_path=prompt_type_path) - self.pad_token_id = self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None else self.tokenizer.tokenizer.eos_token_id + self.tokenizer = get_tokenizer(tokenizer_name_or_path, + prompt_type=prompt_type, prompt_type_path=prompt_type_path) + self.pad_token_id = ( + self.tokenizer.tokenizer.pad_token_id if self.tokenizer.tokenizer.pad_token_id is not None + else self.tokenizer.tokenizer.eos_token_id) # Set up local rank using the helper function self.local_rank = get_local_rank() @@ -130,8 +140,8 @@ class VLLMInferEngine(BaseInferEngine): # Initialize parallel state if tensor parallel size is specified if train_tensor_parallel_size is not None: num_tp_per_train_tp = train_tensor_parallel_size // infer_tensor_parallel_size - os.environ["CUDA_TIMER_STREAM_KAFKA_ENABLE"] = "0" - os.environ["MEGATRON_IMPORT_TIMERS"] = "0" + os.environ['CUDA_TIMER_STREAM_KAFKA_ENABLE'] = '0' + os.environ['MEGATRON_IMPORT_TIMERS'] = '0' initialize_parallel_state( infer_tensor_model_parallel_size=infer_tensor_parallel_size, train_tensor_model_parallel_size=train_tensor_parallel_size, @@ -159,7 +169,7 @@ class VLLMInferEngine(BaseInferEngine): skip_tokenizer_init=False, gpu_memory_utilization=gpu_memory_utilization, max_num_seqs=max_num_seqs, - max_model_len=max_model_len, + max_model_len=max_model_len ) self.model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() @@ -211,9 +221,13 @@ class VLLMInferEngine(BaseInferEngine): for name, params in self.model.named_parameters(): params.data = self.cpu_model[name] - def sync_model_weights(self, params, load_format="megatron"): - infer_parallel_config = InferParallelConfig(self.infer_tensor_parallel_size, self.infer_pipeline_parallel_size, 
self.infer_expert_parallel_size) - load_megatron_weights(params, self.model, infer_parallel_config, self.hf_config) + def sync_model_weights(self, params, load_format='megatron'): + infer_parallel_config = InferParallelConfig(self.infer_tensor_parallel_size, self.infer_pipeline_parallel_size, + self.infer_expert_parallel_size) + load_megatron_weights(params, + self.model, + infer_parallel_config, + self.hf_config) if hasattr(self.model.model.layers[0].self_attn, "mla_attn"): self._process_mla() @@ -224,11 +238,17 @@ class VLLMInferEngine(BaseInferEngine): mla.w_kc = None mla.w_vc = None + @mstx_timer_decorator def generate_sequences(self, idx_list, **kwargs): self.init_cache_engine() with self.update_sampling_params(**kwargs): - response = self.llm.generate(prompts=None, sampling_params=self.sampling_params, prompt_token_ids=idx_list, use_tqdm=False) + response = self.llm.generate( + prompts=None, + sampling_params=self.sampling_params, + prompt_token_ids=idx_list, + use_tqdm=False + ) outs = self._post_process_outputs(response) self.free_cache_engine() return outs @@ -249,9 +269,11 @@ class VLLMInferEngine(BaseInferEngine): logprob.append(logprobs_dict[token_id].logprob) logprobs.append(torch.tensor(logprob)) - output_token_ids = pad_sequence(output_token_ids, batch_first=True, padding_value=self.pad_token_id) + output_token_ids = pad_sequence(output_token_ids, batch_first=True, + padding_value=self.pad_token_id) if len(logprobs) > 0: - logprobs = pad_sequence(logprobs, batch_first=True, padding_value=self.pad_token_id) + logprobs = pad_sequence(logprobs, batch_first=True, + padding_value=self.pad_token_id) return output_token_ids, logprobs @contextmanager @@ -270,7 +292,10 @@ class VLLMInferEngine(BaseInferEngine): setattr(self.sampling_params, key, value) def chat(self, conversation, sampling_params=None): - outputs = self.llm.chat(conversation, sampling_params=sampling_params if sampling_params else self.sampling_params, use_tqdm=False) + outputs = self.llm.chat( + conversation, + sampling_params=sampling_params if sampling_params else self.sampling_params, + use_tqdm=False) return outputs diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 7c2b9bb5..2584840b 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -37,18 +37,25 @@ class ReferenceWorkerBase(BaseWorker): """ def __init__( - self, - megatron_config: MegatronConfig, - rl_config: RLConfig, - generate_config: GenerateConfig, - model_provider: Callable, - initialize_func: Callable, - tokenizer: BaseTokenizer = None, - get_megatron_module: Callable = None, - **kwargs + self, + megatron_config: MegatronConfig, + rl_config: RLConfig, + generate_config: GenerateConfig, + model_provider: Callable, + initialize_func: Callable, + tokenizer: BaseTokenizer = None, + get_megatron_module: Callable = None, + **kwargs ): super().__init__( - megatron_config, rl_config, generate_config, model_provider=model_provider, initialize_func=initialize_func, tokenizer=tokenizer, get_megatron_module=get_megatron_module, **kwargs + megatron_config, + rl_config, + generate_config, + model_provider=model_provider, + initialize_func=initialize_func, + tokenizer=tokenizer, + get_megatron_module=get_megatron_module, + **kwargs ) self.reference = None @@ -57,7 +64,8 @@ class ReferenceWorkerBase(BaseWorker): self.model = self.get_model(self.model_provider, self.model_type, wrap_with_ddp=False) if self.megatron_config.load is not None or 
self.megatron_config.pretrained_checkpoint is not None: - self.megatron_config.iteration, self.megatron_config.num_floating_point_operations_so_far = self.load_checkpoint(self.model, None, None) + self.megatron_config.iteration, self.megatron_config.num_floating_point_operations_so_far = self.load_checkpoint( + self.model, None, None) else: self.megatron_config.iteration = 0 self.megatron_config.num_floating_point_operations_so_far = 0 @@ -74,7 +82,7 @@ class ReferenceWorkerBase(BaseWorker): stage=self.megatron_config.stage, forward_backward_func=self.forward_backward_func, micro_batch_size=self.megatron_config.micro_batch_size, - temperature=self.generate_config.sampling_config["temperature"], + temperature=self.generate_config.sampling_config["temperature"] ) def init_transfer_dock(self, td): @@ -82,26 +90,32 @@ class ReferenceWorkerBase(BaseWorker): @mstx_timer_decorator def compute_ref_log_prob(self): - experience_consumer_stage = "ref_log_prob" - experience_columns = ["input_ids", "responses", "response_length", "prompt_length"] + experience_consumer_stage = 'ref_log_prob' + experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length'] experience_count = self.rl_config.ref_dispatch_size - sorted_indexes = self.get_dp_range_indexes(experience_count, use_vllm=False) if self.rl_config.guarantee_order else None + sorted_indexes = self.get_dp_range_indexes(experience_count, + use_vllm=False) if self.rl_config.guarantee_order else None start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0: - batch_data, index = self.dispatch_transfer_dock_data( - experience_consumer_stage, - experience_columns, - experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None, - get_n_samples=False, - ) + batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, + experience_columns, + experience_count, + tp_size=self.megatron_config.tensor_model_parallel_size, + indexes=sorted_indexes.pop( + 0) if self.rl_config.guarantee_order else None, + get_n_samples=False) if not start_time_defined: start_time = time.time() start_time_defined = True - ray.get(self.td.update_metrics.remote("start_time/reference_model", value=[round(start_time, 4)], cumulate=True)) + ray.get( + self.td.update_metrics.remote( + "start_time/reference_model", + value=[round(start_time, 4)], + cumulate=True + ) + ) if batch_data and index: output, batch = self.reference.compute_log_prob(batch_data) @@ -109,23 +123,33 @@ class ReferenceWorkerBase(BaseWorker): # only on last rank. 
It should be on every tp rank log_probs = torch.cat(output, dim=0) # (bs, seq_size) log_probs = log_probs.to(torch.float32) - log_probs = truncate_rows(log_probs, batch["response_length"]) - output = {"ref_log_prob": log_probs} + log_probs = truncate_rows(log_probs, batch['response_length']) + output = {'ref_log_prob': log_probs} self.collect_transfer_dock_data(output, index) end_time = time.time() - ray.get(self.td.update_metrics.remote("timing/reference_model", value=[round(end_time, 4), round(start_time, 4)], cumulate=True)) + ray.get( + self.td.update_metrics.remote( + "timing/reference_model", + value=[round(end_time, 4), round(start_time, 4)], + cumulate=True + ) + ) parallel_state = get_parallel_state() use_vllm = False if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0: ref_end_time = time.time() - ray.get(self.td.update_metrics.remote("end_time/reference", value=[round(ref_end_time, 4)])) + ray.get( + self.td.update_metrics.remote( + "end_time/reference", + value=[round(ref_end_time, 4)] + ) + ) logger.info("finish compute ref log prob") self.empty_cache() def save_ref_ckpt(self, iteration: int): from megatron.training import get_args - args = get_args() original_save_path = args.save @@ -135,26 +159,31 @@ class ReferenceWorkerBase(BaseWorker): args.save = original_save_path + @torch.no_grad() def pre_ref_update(self): + parallel_state = get_parallel_state() self.megatron_offloader.offload_param() self.ref_params = dict(self.reference.model[0].named_parameters()) - return get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS + return parallel_state._DATA_PARALLEL_GLOBAL_RANKS + @torch.no_grad() def receive_from_actor(self, layer_name: str, tensor_weight: torch.Tensor): ref_key = layer_name.replace("module.module.", "module.") alpha = self.rl_config.ref_model_mixup_alpha self.ref_params[ref_key].mul_(1 - alpha).add_(tensor_weight, alpha=alpha) + @torch.no_grad() def post_ref_update(self): + parallel_state = get_parallel_state() self.megatron_offloader.onload_param() - if len(get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS) > 1: - rank_dp0 = get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS[0] + if len(parallel_state._DATA_PARALLEL_GLOBAL_RANKS) > 1: + rank_dp0 = parallel_state._DATA_PARALLEL_GLOBAL_RANKS[0] for _, v in self.ref_params.items(): - torch.distributed.broadcast(v, src=rank_dp0, group=get_parallel_state()._DATA_PARALLEL_GROUP) + torch.distributed.broadcast(v, src=rank_dp0, group=parallel_state._DATA_PARALLEL_GROUP) @ray.remote(resources={"NPU": 0.3}) -- Gitee From 8bbefb567992a6375b1942bc4729f5e5035dc2a9 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Fri, 6 Jun 2025 14:46:11 +0800 Subject: [PATCH 03/12] . . 
--- mindspeed_rl/models/rollout/vllm_engine.py | 2 +- mindspeed_rl/workers/reference_woker.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 95da450b..c3767000 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -238,7 +238,7 @@ class VLLMInferEngine(BaseInferEngine): mla.w_kc = None mla.w_vc = None - + @torch.no_grad() @mstx_timer_decorator def generate_sequences(self, idx_list, **kwargs): self.init_cache_engine() diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 2584840b..6fe00afa 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -11,7 +11,6 @@ from mindspeed_rl.config_cls.megatron_config import MegatronConfig from mindspeed_rl.config_cls.rl_config import RLConfig from mindspeed_rl.config_cls.generate_config import GenerateConfig from mindspeed_rl.models.reference import Reference -from mindspeed_rl.trainer.utils import parallel_state from mindspeed_rl.utils.pad_process import truncate_rows from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.workers.base_worker import BaseWorker @@ -161,12 +160,11 @@ class ReferenceWorkerBase(BaseWorker): @torch.no_grad() def pre_ref_update(self): - parallel_state = get_parallel_state() self.megatron_offloader.offload_param() self.ref_params = dict(self.reference.model[0].named_parameters()) - return parallel_state._DATA_PARALLEL_GLOBAL_RANKS + return get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS @torch.no_grad() def receive_from_actor(self, layer_name: str, tensor_weight: torch.Tensor): @@ -177,13 +175,12 @@ class ReferenceWorkerBase(BaseWorker): @torch.no_grad() def post_ref_update(self): - parallel_state = get_parallel_state() self.megatron_offloader.onload_param() - if len(parallel_state._DATA_PARALLEL_GLOBAL_RANKS) > 1: - rank_dp0 = parallel_state._DATA_PARALLEL_GLOBAL_RANKS[0] + if len(get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS) > 1: + rank_dp0 = get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS[0] for _, v in self.ref_params.items(): - torch.distributed.broadcast(v, src=rank_dp0, group=parallel_state._DATA_PARALLEL_GROUP) + torch.distributed.broadcast(v, src=rank_dp0, group=get_parallel_state()._DATA_PARALLEL_GROUP) @ray.remote(resources={"NPU": 0.3}) -- Gitee From 5dd7958f8b174bd50cdbbd7f29b4080770371803 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Fri, 6 Jun 2025 14:51:41 +0800 Subject: [PATCH 04/12] Update reference_woker.py Update test_grpo_trainer_qwen25_7b_ref_split.yaml --- mindspeed_rl/workers/reference_woker.py | 3 --- .../st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml | 7 +++++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 6fe00afa..fde9aa28 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -158,7 +158,6 @@ class ReferenceWorkerBase(BaseWorker): args.save = original_save_path - @torch.no_grad() def pre_ref_update(self): self.megatron_offloader.offload_param() @@ -166,14 +165,12 @@ class ReferenceWorkerBase(BaseWorker): return get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS - @torch.no_grad() def receive_from_actor(self, layer_name: str, tensor_weight: torch.Tensor): ref_key = layer_name.replace("module.module.", "module.") alpha = 
self.rl_config.ref_model_mixup_alpha self.ref_params[ref_key].mul_(1 - alpha).add_(tensor_weight, alpha=alpha) - @torch.no_grad() def post_ref_update(self): self.megatron_offloader.onload_param() diff --git a/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml b/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml index 5d6801be..b264d0f3 100644 --- a/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml +++ b/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml @@ -31,12 +31,15 @@ megatron_training: dataset_additional_keys: ['labels',] data_path: /workspace/dataset/pe_nlp/pe_nlp split: 100,0,0 + no_shuffle: true + full_shuffle_instruction_dataset: false + seed: 1234 actor_config: model: qwen25_7b micro_batch_size: 4 - tensor_model_parallel_size: 4 - pipeline_model_parallel_size: 1 + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 2 lr: 5e-7 lr_decay_style: cosine min_lr: 5e-8 -- Gitee From fa273d24677a23ae179ae01e3a1963b190f704f7 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 11:41:04 +0800 Subject: [PATCH 05/12] . --- mindspeed_rl/workers/reference_woker.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index fde9aa28..6fe00afa 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -158,6 +158,7 @@ class ReferenceWorkerBase(BaseWorker): args.save = original_save_path + @torch.no_grad() def pre_ref_update(self): self.megatron_offloader.offload_param() @@ -165,12 +166,14 @@ class ReferenceWorkerBase(BaseWorker): return get_parallel_state()._DATA_PARALLEL_GLOBAL_RANKS + @torch.no_grad() def receive_from_actor(self, layer_name: str, tensor_weight: torch.Tensor): ref_key = layer_name.replace("module.module.", "module.") alpha = self.rl_config.ref_model_mixup_alpha self.ref_params[ref_key].mul_(1 - alpha).add_(tensor_weight, alpha=alpha) + @torch.no_grad() def post_ref_update(self): self.megatron_offloader.onload_param() -- Gitee From 3c91cb5f64332a8bfd9e1e05b52cb7f9d36d2dc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=AB=E5=B2=9A=E5=BB=B6=E6=97=AD?= Date: Sat, 7 Jun 2025 04:03:38 +0000 Subject: [PATCH 06/12] add tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 枫岚延旭 --- .../test_full_grpo_qwen25_32b_ref_update.sh | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh diff --git a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh new file mode 100644 index 00000000..193438e2 --- /dev/null +++ b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh @@ -0,0 +1,106 @@ +# pkill -9 python +ray stop --force +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export HCCL_DETERMINISTIC=True + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +export PYTHONPATH=$SCRIPT_DIR/../:$PYTHONPATH +PROJECT_PATH=$SCRIPT_DIR/../../.. 
+PROFILER_DATA_PATH=$PROJECT_PATH/ci/profiler_data +rm -rf $PROFILER_DATA_PATH + +mkdir -p $PROJECT_PATH/logs + +export GLOO_SOCKET_IFNAME=ens3f0 +export TP_SOCKET_IFNAME=ens3f0 +export HCCL_SOCKET_IFNAME=ens3f0 + +export HCCL_INTRA_ROCE_ENABLE=1 +# export RAY_DEBUG_POST_MORTEM=1 +export TASK_QUEUE_ENABLE=2 +export HCCL_IF_BASE_PORT=24703 + +tp=4 +pp=2 +load_ckpt=/workspace/weight/Qwen2.5-32B-Instruct-tp4pp2 + +YAML=test_grpo_trainer_qwen25_32b_ref_split + +NNODES=2 +NPUS_PER_NODE=16 +#修改为对应主节点IP +MASTER_ADDR="141.61.41.163" +#获取当前机器IP +CURRENT_IP=$(ip -4 addr show $(ip -o -4 route show to default | awk '{print $5}') | grep -oP '(?<=inet\s)\d+(\.\d+){3}') +echo "CURRENT_IP: $CURRENT_IP" + +if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then + # 主节点启动 + ray start --head --port 6766 --dashboard-host=$MASTER_ADDR --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}' + + while true; do + ray_status_output=$(ray status) + npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) + npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') + device_count=$((npu_count_int / $NPUS_PER_NODE)) + + # 判断 device_count 是否与 NNODES 相等 + if [ "$device_count" -eq "$NNODES" ]; then + echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script." + ray status + + python $PROJECT_PATH/cli/train_grpo.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=$YAML \ + megatron_training.model=qwen25_32b \ + actor_config.tensor_model_parallel_size=$tp \ + actor_config.pipeline_model_parallel_size=$pp \ + actor_config.load=$load_ckpt \ + ref_config.tensor_model_parallel_size=$tp \ + ref_config.pipeline_model_parallel_size=$pp \ + ref_config.load=$load_ckpt \ + ref_config.model=qwen25_32b \ + ref_config.micro_batch_size=2 \ + megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-32B-Instruct \ + megatron_training.global_batch_size=32 \ + megatron_training.no_shuffle=true \ + megatron_training.full_shuffle_instruction_dataset=false \ + megatron_training.seed=1234 \ + actor_config.model=qwen25_32b \ + actor_config.micro_batch_size=2 \ + rl_config.actor_resource.num_npus=16 \ + rl_config.reference_resource.num_npus=16 \ + rl_config.mini_batch_size=32 \ + megatron_training.train_iters=500 \ + rl_config.ref_save=/workspace/weight/ref_save \ + actor_config.save=/workspace/weight/actor_save \ + megatron_training.save_interval=20 \ + rl_config.ref_model_sync_steps=1 \ + megatron_training.global_batch_size=32 \ + generate_config.infer_tensor_parallel_size=8 \ + generate_config.max_num_seqs=1024 \ + 2>&1 | tee $PROJECT_PATH/logs/test_Qwen32B_32p_update_ref_full.log + + break + else + echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" + sleep 5 + fi + done +else + # 子节点尝试往主节点注册ray直到成功 + while true; do + # 尝试连接 Ray 集群 + ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP + + # 检查连接是否成功 + ray status + if [ $? -eq 0 ]; then + echo "Successfully connected to the Ray cluster!" + break + else + echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." + sleep 5 + fi + done +fi \ No newline at end of file -- Gitee From 7ab8f16f05be02d7b6d9a3b3f92957d089531e75 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 12:04:48 +0800 Subject: [PATCH 07/12] . Update actor_hybrid_worker.py . . 
--- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 1 + mindspeed_rl/workers/actor_hybrid_worker.py | 2 +- ...test_grpo_trainer_qwen25_7b_ref_split.yaml | 1 + .../test_full_grpo_qwen25_32b_ref_update.sh | 210 +++++++++--------- 4 files changed, 108 insertions(+), 106 deletions(-) diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index d980930b..99d73dd7 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -254,6 +254,7 @@ class RayGRPOTrainer(RayBaseTrainer): ref_sync_objs.append(ref.pre_ref_update.remote()) list_data_parallel_global_ranks = ray.get(ref_sync_objs) list_dp0_rank = [ranks[0] for ranks in list_data_parallel_global_ranks] + list_dp0_rank = list(set(list_dp0_rank)) #* Step 3 offer download command with Timer(name="sync ref", logger=None) as timer: diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 6b08c6be..d0ca942b 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -369,7 +369,7 @@ class ActorHybridWorkerBase(BaseWorker): peer_rank = ray.get(peer_handler.get_value.remote("_local_rank")) for i, (key, val) in enumerate(list_param): size = math.prod(val.shape) * val.dtype.itemsize / 1024 / 1024 - logger.info(f"NPU: {self._local_rank} --> {peer_rank}, ({i:-4d} / {num_param:-4d}) {size:.2f} MB", key) + logger.debug(f"NPU: {self._local_rank} --> {peer_rank}, ({i:-4d} / {num_param:-4d}) {size:.2f} MB", key) ray.get(peer_handler.receive_from_actor.remote(key, val.cpu())) diff --git a/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml b/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml index b264d0f3..526b4a7d 100644 --- a/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml +++ b/tests/st/configs/test_grpo_trainer_qwen25_7b_ref_split.yaml @@ -1,6 +1,7 @@ defaults: - model: - qwen25_7b + - qwen25_32b megatron_training: model: qwen25_7b diff --git a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh index 193438e2..f32b17a1 100644 --- a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh +++ b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh @@ -1,106 +1,106 @@ -# pkill -9 python -ray stop --force -export RAY_DEDUP_LOGS=0 -export HYDRA_FULL_ERROR=1 -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export HCCL_DETERMINISTIC=True - -SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) -export PYTHONPATH=$SCRIPT_DIR/../:$PYTHONPATH -PROJECT_PATH=$SCRIPT_DIR/../../.. 
-PROFILER_DATA_PATH=$PROJECT_PATH/ci/profiler_data -rm -rf $PROFILER_DATA_PATH - -mkdir -p $PROJECT_PATH/logs - -export GLOO_SOCKET_IFNAME=ens3f0 -export TP_SOCKET_IFNAME=ens3f0 -export HCCL_SOCKET_IFNAME=ens3f0 - -export HCCL_INTRA_ROCE_ENABLE=1 -# export RAY_DEBUG_POST_MORTEM=1 -export TASK_QUEUE_ENABLE=2 -export HCCL_IF_BASE_PORT=24703 - -tp=4 -pp=2 -load_ckpt=/workspace/weight/Qwen2.5-32B-Instruct-tp4pp2 - -YAML=test_grpo_trainer_qwen25_32b_ref_split - -NNODES=2 -NPUS_PER_NODE=16 -#修改为对应主节点IP -MASTER_ADDR="141.61.41.163" -#获取当前机器IP -CURRENT_IP=$(ip -4 addr show $(ip -o -4 route show to default | awk '{print $5}') | grep -oP '(?<=inet\s)\d+(\.\d+){3}') -echo "CURRENT_IP: $CURRENT_IP" - -if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then - # 主节点启动 - ray start --head --port 6766 --dashboard-host=$MASTER_ADDR --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}' - - while true; do - ray_status_output=$(ray status) - npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) - npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') - device_count=$((npu_count_int / $NPUS_PER_NODE)) - - # 判断 device_count 是否与 NNODES 相等 - if [ "$device_count" -eq "$NNODES" ]; then - echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script." - ray status - - python $PROJECT_PATH/cli/train_grpo.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=$YAML \ - megatron_training.model=qwen25_32b \ - actor_config.tensor_model_parallel_size=$tp \ - actor_config.pipeline_model_parallel_size=$pp \ - actor_config.load=$load_ckpt \ - ref_config.tensor_model_parallel_size=$tp \ - ref_config.pipeline_model_parallel_size=$pp \ - ref_config.load=$load_ckpt \ - ref_config.model=qwen25_32b \ - ref_config.micro_batch_size=2 \ - megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-32B-Instruct \ - megatron_training.global_batch_size=32 \ - megatron_training.no_shuffle=true \ - megatron_training.full_shuffle_instruction_dataset=false \ - megatron_training.seed=1234 \ - actor_config.model=qwen25_32b \ - actor_config.micro_batch_size=2 \ - rl_config.actor_resource.num_npus=16 \ - rl_config.reference_resource.num_npus=16 \ - rl_config.mini_batch_size=32 \ - megatron_training.train_iters=500 \ - rl_config.ref_save=/workspace/weight/ref_save \ - actor_config.save=/workspace/weight/actor_save \ - megatron_training.save_interval=20 \ - rl_config.ref_model_sync_steps=1 \ - megatron_training.global_batch_size=32 \ - generate_config.infer_tensor_parallel_size=8 \ - generate_config.max_num_seqs=1024 \ - 2>&1 | tee $PROJECT_PATH/logs/test_Qwen32B_32p_update_ref_full.log - - break - else - echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" - sleep 5 - fi - done -else - # 子节点尝试往主节点注册ray直到成功 - while true; do - # 尝试连接 Ray 集群 - ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP - - # 检查连接是否成功 - ray status - if [ $? -eq 0 ]; then - echo "Successfully connected to the Ray cluster!" - break - else - echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." - sleep 5 - fi - done +# pkill -9 python +ray stop --force +export RAY_DEDUP_LOGS=0 +export HYDRA_FULL_ERROR=1 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export HCCL_DETERMINISTIC=True + +SCRIPT_DIR=$(cd "$(dirname "$0")" && pwd) +export PYTHONPATH=$SCRIPT_DIR/../:$PYTHONPATH +PROJECT_PATH=$SCRIPT_DIR/../../.. 
+PROFILER_DATA_PATH=$PROJECT_PATH/ci/profiler_data +rm -rf $PROFILER_DATA_PATH + +mkdir -p $PROJECT_PATH/logs + +export GLOO_SOCKET_IFNAME=ens3f0 +export TP_SOCKET_IFNAME=ens3f0 +export HCCL_SOCKET_IFNAME=ens3f0 + +export HCCL_INTRA_ROCE_ENABLE=1 +# export RAY_DEBUG_POST_MORTEM=1 +export TASK_QUEUE_ENABLE=2 +export HCCL_IF_BASE_PORT=24703 + +tp=4 +pp=2 +load_ckpt=/workspace/weight/Qwen2.5-32B-Instruct-tp4pp2 + +YAML=test_grpo_trainer_qwen25_32b_ref_split + +NNODES=2 +NPUS_PER_NODE=16 +#修改为对应主节点IP +MASTER_ADDR="141.61.41.163" +#获取当前机器IP +CURRENT_IP=$(ip -4 addr show $(ip -o -4 route show to default | awk '{print $5}') | grep -oP '(?<=inet\s)\d+(\.\d+){3}') +echo "CURRENT_IP: $CURRENT_IP" + +if [ "$MASTER_ADDR" = "$CURRENT_IP" ]; then + # 主节点启动 + ray start --head --port 6766 --dashboard-host=$MASTER_ADDR --node-ip-address=$CURRENT_IP --dashboard-port=8260 --resources='{"NPU": '$NPUS_PER_NODE'}' + + while true; do + ray_status_output=$(ray status) + npu_count=$(echo "$ray_status_output" | grep -oP '(?<=/)\d+\.\d+(?=\s*NPU)' | head -n 1) + npu_count_int=$(echo "$npu_count" | awk '{print int($1)}') + device_count=$((npu_count_int / $NPUS_PER_NODE)) + + # 判断 device_count 是否与 NNODES 相等 + if [ "$device_count" -eq "$NNODES" ]; then + echo "Ray cluster is ready with $device_count devices (from $npu_count NPU resources), starting Python script." + ray status + + python $PROJECT_PATH/cli/train_grpo.py --config-dir="$PROJECT_PATH"/tests/st/configs --config-name=$YAML \ + megatron_training.model=qwen25_32b \ + actor_config.tensor_model_parallel_size=$tp \ + actor_config.pipeline_model_parallel_size=$pp \ + actor_config.load=$load_ckpt \ + ref_config.tensor_model_parallel_size=$tp \ + ref_config.pipeline_model_parallel_size=$pp \ + ref_config.load=$load_ckpt \ + ref_config.model=qwen25_32b \ + ref_config.micro_batch_size=2 \ + megatron_training.tokenizer_name_or_path=/workspace/weight/Qwen2.5-32B-Instruct \ + megatron_training.global_batch_size=32 \ + megatron_training.no_shuffle=true \ + megatron_training.full_shuffle_instruction_dataset=false \ + megatron_training.seed=1234 \ + actor_config.model=qwen25_32b \ + actor_config.micro_batch_size=2 \ + rl_config.actor_resource.num_npus=16 \ + rl_config.reference_resource.num_npus=16 \ + rl_config.mini_batch_size=32 \ + megatron_training.train_iters=500 \ + rl_config.ref_save=/workspace/weight/ref_save \ + actor_config.save=/workspace/weight/actor_save \ + megatron_training.save_interval=20 \ + rl_config.ref_model_sync_steps=1 \ + megatron_training.global_batch_size=32 \ + generate_config.infer_tensor_parallel_size=8 \ + generate_config.max_num_seqs=1024 \ + 2>&1 | tee $PROJECT_PATH/logs/test_Qwen32B_32p_update_ref_full.log + + break + else + echo "Waiting for Ray to allocate $NNODES devices. Current device count: $device_count" + sleep 5 + fi + done +else + # 子节点尝试往主节点注册ray直到成功 + while true; do + # 尝试连接 Ray 集群 + ray start --address="$MASTER_ADDR:6766" --resources='{"NPU": '$NPUS_PER_NODE'}' --node-ip-address=$CURRENT_IP + + # 检查连接是否成功 + ray status + if [ $? -eq 0 ]; then + echo "Successfully connected to the Ray cluster!" + break + else + echo "Failed to connect to the Ray cluster. Retrying in 5 seconds..." + sleep 5 + fi + done fi \ No newline at end of file -- Gitee From 1bdb1e7b1550d1f661870787aed47b24ab34ec0c Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 14:45:51 +0800 Subject: [PATCH 08/12] . 
--- tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh index f32b17a1..f26ead3c 100644 --- a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh +++ b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh @@ -26,7 +26,7 @@ tp=4 pp=2 load_ckpt=/workspace/weight/Qwen2.5-32B-Instruct-tp4pp2 -YAML=test_grpo_trainer_qwen25_32b_ref_split +YAML=test_grpo_trainer_qwen25_7b_ref_split #* modified in cmd NNODES=2 NPUS_PER_NODE=16 -- Gitee From a1ca1812ab3c0427c8cc560d642f7334c585d6c4 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 14:46:46 +0800 Subject: [PATCH 09/12] Update grpo_trainer_hybrid.py --- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index 99d73dd7..ca99986f 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -254,7 +254,7 @@ class RayGRPOTrainer(RayBaseTrainer): ref_sync_objs.append(ref.pre_ref_update.remote()) list_data_parallel_global_ranks = ray.get(ref_sync_objs) list_dp0_rank = [ranks[0] for ranks in list_data_parallel_global_ranks] - list_dp0_rank = list(set(list_dp0_rank)) + list_dp0_rank = sorted(list(set(list_dp0_rank))) #* Step 3 offer download command with Timer(name="sync ref", logger=None) as timer: -- Gitee From 6aed9bb9e39d439af89b359a61b4e04c634d7a3b Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 14:57:16 +0800 Subject: [PATCH 10/12] Update test_full_grpo_qwen25_32b_ref_update.sh --- tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh index f26ead3c..0edfb3b2 100644 --- a/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh +++ b/tests/st/ref_update/test_full_grpo_qwen25_32b_ref_update.sh @@ -13,10 +13,6 @@ rm -rf $PROFILER_DATA_PATH mkdir -p $PROJECT_PATH/logs -export GLOO_SOCKET_IFNAME=ens3f0 -export TP_SOCKET_IFNAME=ens3f0 -export HCCL_SOCKET_IFNAME=ens3f0 - export HCCL_INTRA_ROCE_ENABLE=1 # export RAY_DEBUG_POST_MORTEM=1 export TASK_QUEUE_ENABLE=2 @@ -31,7 +27,7 @@ YAML=test_grpo_trainer_qwen25_7b_ref_split #* modified in cmd NNODES=2 NPUS_PER_NODE=16 #修改为对应主节点IP -MASTER_ADDR="141.61.41.163" +MASTER_ADDR="host ip" #获取当前机器IP CURRENT_IP=$(ip -4 addr show $(ip -o -4 route show to default | awk '{print $5}') | grep -oP '(?<=inet\s)\d+(\.\d+){3}') echo "CURRENT_IP: $CURRENT_IP" -- Gitee From c834587e5fe7798ff2c3de9c682e55ff6b9a9c5d Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 15:11:12 +0800 Subject: [PATCH 11/12] Update actor_hybrid_worker.py --- mindspeed_rl/workers/actor_hybrid_worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index d0ca942b..67cd1c01 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -358,7 +358,7 @@ class ActorHybridWorkerBase(BaseWorker): import logging import math - logger = Loggers(f"sync_weight_to_ref_rank", logger_level=logging.DEBUG) + logger = Loggers(f"sync_weight_to_ref_rank") 
self.empty_cache() self.sharding_manager.megatron_offloader.onload_param() @@ -367,9 +367,12 @@ class ActorHybridWorkerBase(BaseWorker): num_param = len(list_param) peer_rank = ray.get(peer_handler.get_value.remote("_local_rank")) + peer_ip = ray.get(peer_handler._get_current_node_ip.remote()) + local_ip = self._get_current_node_ip() + for i, (key, val) in enumerate(list_param): size = math.prod(val.shape) * val.dtype.itemsize / 1024 / 1024 - logger.debug(f"NPU: {self._local_rank} --> {peer_rank}, ({i:-4d} / {num_param:-4d}) {size:.2f} MB", key) + logger.debug(f"{local_ip} NPU: {self._local_rank} --> {peer_ip} NPU: {peer_rank}, ({i:-4d} / {num_param:-4d}) {size:.2f} MB", key) ray.get(peer_handler.receive_from_actor.remote(key, val.cpu())) -- Gitee From 10c72b3b0a6672517e6b097f29661398e0b4d726 Mon Sep 17 00:00:00 2001 From: WangLingfeng Date: Sat, 7 Jun 2025 15:46:22 +0800 Subject: [PATCH 12/12] Update test_ref_update.py --- tests/st/ref_update/test_ref_update.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/st/ref_update/test_ref_update.py b/tests/st/ref_update/test_ref_update.py index 0b7e332c..3b1c8b05 100644 --- a/tests/st/ref_update/test_ref_update.py +++ b/tests/st/ref_update/test_ref_update.py @@ -5,16 +5,13 @@ import sys import logging -from datetime import timedelta -from tokenize import group -from unittest import TestCase from unittest.mock import MagicMock +from pathlib import Path from venv import logger import ray import hydra import torch import yaml -from pathlib import Path from mindspeed_rl.config_cls.validate_config import validate_rl_args from mindspeed_rl.models.rollout.vllm_engine import VLLMInferEngine -- Gitee
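
Taken together, the series wires up the Ref-split sync_ref path as follows: the actor onloads its parameters, streams them tensor by tensor to the reference workers (sync_weight_to_ref_rank -> receive_from_actor), each reference worker blends the incoming weight into its own copy using ref_model_mixup_alpha, and post_ref_update then broadcasts the blended parameters from the data-parallel rank 0 so every DP replica ends up identical. The snippet below is a minimal, single-process sketch of the blend step only; the toy nn.Linear modules, the mixup_ref_update helper, and the driver loop are illustrative stand-ins rather than repository APIs, and the alpha/step values simply mirror the defaults seen in the patch and its test scripts.

# Minimal sketch of the mixup-style reference update performed per tensor by
# receive_from_actor. The models, function name, and loop here are assumptions
# for illustration; only the update rule ref <- alpha*actor + (1-alpha)*ref and
# the config names ref_model_mixup_alpha / ref_model_sync_steps come from the patch.
import torch
import torch.nn as nn


@torch.no_grad()
def mixup_ref_update(ref_model: nn.Module, actor_model: nn.Module, alpha: float) -> None:
    """In-place soft update of the reference weights toward the actor weights."""
    ref_params = dict(ref_model.named_parameters())
    for name, actor_weight in actor_model.named_parameters():
        # The worker maps DDP-wrapped "module.module." key prefixes down to
        # "module." before matching; with plain nn.Modules the names already line up.
        ref_params[name].mul_(1.0 - alpha).add_(actor_weight, alpha=alpha)


if __name__ == "__main__":
    torch.manual_seed(0)
    ref, actor = nn.Linear(4, 4), nn.Linear(4, 4)
    ref_model_mixup_alpha = 0.9   # default alpha used by the config class
    ref_model_sync_steps = 1      # the test scripts sync the reference every step

    for step in range(1, 4):
        # ... an actor optimizer step would normally happen here ...
        if step % ref_model_sync_steps == 0:
            mixup_ref_update(ref, actor, ref_model_mixup_alpha)

    # With a fixed actor, repeated mixup updates pull the reference toward it.
    gap = (ref.weight - actor.weight).abs().max().item()
    print(f"max |ref - actor| after {step} syncs: {gap:.2e}")

Blending rather than hard-copying keeps the reference policy a slowly moving average of the actor. In the patch itself the actor only pushes weights to each group's DP-0 reference worker (the trainer deduplicates list_dp0_rank), and post_ref_update broadcasts from that rank across the data-parallel group, which keeps the remaining replicas consistent without every actor rank repeating the ray transfer.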