From d585ff41af48b0d0e3dd7901e0b7c906b163b071 Mon Sep 17 00:00:00 2001
From: A1waysBeenHere
Date: Thu, 19 Jun 2025 14:53:51 +0800
Subject: [PATCH 1/5] Support Qwen3

---
 cli/train_grpo.py                                  |   4 +-
 configs/grpo_qwen3_8b_A3.yaml                      | 109 ++++++++++++++++++
 configs/model/qwen3_8b.yaml                        |  20 ++++
 .../vllm_adapter/megatron_weight_loaders.py        |   1 +
 .../workers/resharding/weight_adaptor.py           |  45 ++++++++
 5 files changed, 177 insertions(+), 2 deletions(-)
 create mode 100644 configs/grpo_qwen3_8b_A3.yaml
 create mode 100644 configs/model/qwen3_8b.yaml

diff --git a/cli/train_grpo.py b/cli/train_grpo.py
index 0994fe14..04109a57 100644
--- a/cli/train_grpo.py
+++ b/cli/train_grpo.py
@@ -318,7 +318,7 @@ def gpt_model_provider(pre_process, post_process):
     if args.spec is not None:
         transformer_layer_spec = import_module(args.spec)
     else:
-        transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
+        transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, qk_layernorm=args.qk_layernorm)
 
     model = GPTModel(
         config=config,
@@ -363,7 +363,7 @@ def rm_model_provider(pre_process, post_process):
     if args.spec is not None:
         transformer_layer_spec = import_module(args.spec)
     else:
-        transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
+        transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, qk_layernorm=args.qk_layernorm)
 
     if (not args.untie_embeddings_and_output_weights) and (args.pipeline_model_parallel_size > 1):
         args.untie_embeddings_and_output_weights = True
diff --git a/configs/grpo_qwen3_8b_A3.yaml b/configs/grpo_qwen3_8b_A3.yaml
new file mode 100644
index 00000000..af2781fa
--- /dev/null
+++ b/configs/grpo_qwen3_8b_A3.yaml
@@ -0,0 +1,109 @@
+defaults:
+  - model:
+      - qwen3_8b
+
+megatron_training:
+  model: qwen3_8b
+  use_fused_rmsnorm: true
+  use_mcore_models: true
+  sequence_parallel: true
+  use_flash_attn: true
+  no_masked_softmax_fusion: true
+  attention_softmax_in_fp32: true
+  no_gradient_accumulation_fusion: true
+  use_fused_swiglu: true
+  use_fused_rotary_pos_emb: true
+  bf16: true
+  use_distributed_optimizer: true
+  tokenizer_type: PretrainedFromHF
+  tokenizer_name_or_path: ./Qwen3-8B
+  global_batch_size: 32
+  seq_length: 2048
+  save_interval: 100
+  train_iters: 500
+  stage: ray_grpo
+  attention_dropout: 0.0
+  init_method_std: 0.01
+  hidden_dropout: 0.0
+  distributed_backend: nccl
+  no_shared_storage: true
+  dataset_additional_keys: ['labels',]
+  data_path: ./dataset/deepscaler/data
+  split: 100,0,0
+  no_shuffle: false
+  full_shuffle_instruction_dataset: false
+
+actor_config:
+  model: qwen3_8b
+  micro_batch_size: 1
+  tensor_model_parallel_size: 2
+  pipeline_model_parallel_size: 2
+  lr: 1e-6
+  lr_decay_style: constant
+  min_lr: 0
+  weight_decay: 0.01
+  lr_warmup_fraction: 0.0
+  clip_grad: 1.0
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  finetune: true
+  load: ./ckpt
+  save: ./ckpt
+  no_load_optim: true
+  no_load_rng: true
+
+rl_config:
+  guarantee_order: false
+  use_integrated_worker: true
+  blocking: true
+  gamma: 1.0
+  lam: 0.95
+  actor_forward_micro_batch_size: 1
+  ref_forward_micro_batch_size: 1
+  adv_estimator: group_norm
+  kl_penalty: low_var_kl
+  kl_ctrl_type: fixed
+  init_kl_coef: 0.001
+  mini_batch_size: 32
+  max_prompt_length: 2048
+  epochs: 1
+  clip_ratio: 0.2
+  entropy_coeff: 0.0
+  shuffle_minibatch: false
+  n_samples_per_prompt: 8
+  rule_reward: true
+  verifier_function: ["base_acc"]
+  verifier_weight: [1.0]
+  num_cpus_for_local_task: 1.0
+  use_tensorboard: true
+
+actor_resource:
+  num_npus: 16
+
+generate_config:
+  enforce_eager: True
+  trust_remote_code: true
+  offload_train_optimizer: true
+  offload_train_grad: true
+  offload_train_param: true
+
+  # Parallelism settings for inference
+  infer_tensor_parallel_size: 1
+  infer_pipeline_parallel_size: 1
+  infer_expert_parallel_size: 1
+
+  # vLLM model settings
+  max_num_seqs: 1024
+  max_model_len: 10240
+  max_num_batched_tokens: 10240
+  dtype: "bfloat16"
+  gpu_memory_utilization: 0.4
+
+  # Sampling settings
+  sampling_config:
+    logprobs: 1
+    max_tokens: 8192
+    top_p: 1.0
+    top_k: -1
+    min_p: 0.0
+    temperature: 1.0
+    detokenize: false
\ No newline at end of file
diff --git a/configs/model/qwen3_8b.yaml b/configs/model/qwen3_8b.yaml
new file mode 100644
index 00000000..7a0c3b9b
--- /dev/null
+++ b/configs/model/qwen3_8b.yaml
@@ -0,0 +1,20 @@
+qwen3_8b:
+  use_mcore_models: true
+  num_layers: 36
+  hidden_size: 4096
+  ffn_hidden_size: 12288
+  num_attention_heads: 32
+  rotary_base: 1000000
+  max_position_embeddings: 40960
+  make_vocab_size_divisible_by: 1
+  padded_vocab_size: 151936
+  untie_embeddings_and_output_weights: true
+  disable_bias_linear: true
+  group_query_attention: true
+  num_query_groups: 8
+  position_embedding_type: rope
+  normalization: RMSNorm
+  swiglu: true
+  attention_softmax_in_fp32: true
+  attention_bias: false
+  qk_layernorm: true
diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
index 5ace26b0..a5b5e7f0 100644
--- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
+++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py
@@ -218,6 +218,7 @@ def parallel_weight_loader(self, param: torch.Tensor, loaded_weight: torch.Tensor
 MODEL_MEGATRON_WEIGHT_LOADER_REGISTRY = {
     "LlamaForCausalLM": llama_megatron_core_weight_loader,
     "Qwen2ForCausalLM": qwen_megatron_weight_loader,
+    "Qwen3ForCausalLM": qwen_megatron_weight_loader,
     "DeepseekV3ForCausalLM": deepseek_megatron_weight_loader,
     "DeepseekV2ForCausalLM": deepseek_megatron_weight_loader,
     "CustomDeepseekV3ForCausalLM": deepseek_megatron_weight_loader,
diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py
index 16118f2b..5d06eeb7 100644
--- a/mindspeed_rl/workers/resharding/weight_adaptor.py
+++ b/mindspeed_rl/workers/resharding/weight_adaptor.py
@@ -293,6 +293,50 @@ class QwenMVWeightAdaptor(MegatronVLLMWeightAdaptor):
         super(QwenMVWeightAdaptor, self).__init__(model_config)
 
 
+
+class Qwen3MvWeightAdaptor(MegatronVLLMWeightAdaptor):
+    def __init__(self, model_config):
+        super(MegatronVLLMWeightAdaptor, self).__init__()
+        self.model_config = model_config
+        self.meta_info = None
+        self.params_mapping = [
+            # (megatron core gpt model name, vllm model name)
+            ("embedding.word_embeddings", "model.embed_tokens"),
+            ("self_attention.linear_qkv", "self_attn.qkv_proj"),
+            ("self_attention.linear_proj", "self_attn.o_proj"),
+            ("input_layernorm", "input_layernorm"),
+            ("pre_mlp_layernorm", "post_attention_layernorm"),
+            ("mlp.linear_fc1", "mlp.gate_up_proj"),
+            ("mlp.linear_fc2", "mlp.down_proj"),
+            ("decoder.final_layernorm", "model.norm"),
+            ("output_layer", "lm_head"),
+            ("self_attention.q_layernorm", "self_attn.q_norm")
+            ("self_attention.k_layernorm", "self_attn.k_norm")
+        ]
+
+    def replace_name_i2t(self, inference_name):
+        """
+        transfer inference weight name to training weight name
+        """
+        for m_name, v_name in self.params_mapping:
+            if v_name not in inference_name:
+                continue
+            if "layers" in inference_name:  # deal with decoder layers
+                inference_name = inference_name.replace("model", "decoder")
+                vllm_name_list = inference_name.split(".")
+                param_name_list = vllm_name_list[:3]
+                weight_or_bias = vllm_name_list[-1]
+                param_name_list.append(m_name)
+                if weight_or_bias in ['weight', 'bias']:
+                    param_name_list.append(weight_or_bias)
+                param_name = ".".join(param_name_list)
+                return param_name
+            else:
+                param_name = inference_name.replace(v_name, m_name)
+                return param_name
+
+
+
 class Qwen2_5_VLWeightAdaptor(MegatronVLLMWeightAdaptor):
     def __init__(self, model_config):
         super(Qwen2_5_VLWeightAdaptor, self).__init__(model_config)
@@ -582,6 +626,7 @@ class Qwen2_5_VLWeightAdaptor(MegatronVLLMWeightAdaptor):
 
 WEIGHT_ADAPTOR_REGISTRY = {
     "Qwen2ForCausalLM": QwenMVWeightAdaptor,
+    "Qwen3ForCausalLM": Qwen3MvWeightAdaptor,
     "DeepseekV3ForCausalLM": DeepSeekMVWeightAdaptor,
     "DeepseekV2ForCausalLM": DeepSeekMVWeightAdaptor,
     "CustomDeepseekV3ForCausalLM": DeepSeekMVWeightAdaptor,
-- 
Gitee

From 4b0af6de80ccce300fabbb54eca11639f6b9d7bc Mon Sep 17 00:00:00 2001
From: A1waysBeenHere
Date: Thu, 19 Jun 2025 16:25:26 +0800
Subject: [PATCH 2/5] Add seed when init LLM engine

---
 mindspeed_rl/models/rollout/vllm_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py
index 65625d64..e4d37587 100644
--- a/mindspeed_rl/models/rollout/vllm_engine.py
+++ b/mindspeed_rl/models/rollout/vllm_engine.py
@@ -179,6 +179,7 @@ class VLLMInferEngine(BaseInferEngine):
             gpu_memory_utilization=gpu_memory_utilization,
             max_num_seqs=max_num_seqs,
             max_model_len=max_model_len,
+            seed=int(os.getenv("RANK", "0")) // infer_tensor_parallel_size
             additional_config={
                 'expert_tensor_parallel_size': infer_expert_tensor_parallel_size,
                 'enable_graph_mode': int(os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')),
-- 
Gitee

From bc936db26a7b4662f10458e60b352ff62e8d5769 Mon Sep 17 00:00:00 2001
From: A1waysBeenHere
Date: Thu, 19 Jun 2025 16:26:12 +0800
Subject: [PATCH 3/5] Add seed when init LLM engine

---
 mindspeed_rl/models/rollout/vllm_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py
index e4d37587..90535c30 100644
--- a/mindspeed_rl/models/rollout/vllm_engine.py
+++ b/mindspeed_rl/models/rollout/vllm_engine.py
@@ -179,7 +179,7 @@ class VLLMInferEngine(BaseInferEngine):
             gpu_memory_utilization=gpu_memory_utilization,
             max_num_seqs=max_num_seqs,
             max_model_len=max_model_len,
-            seed=int(os.getenv("RANK", "0")) // infer_tensor_parallel_size
+            seed=int(os.getenv("RANK", "0")) // infer_tensor_parallel_size,
             additional_config={
                 'expert_tensor_parallel_size': infer_expert_tensor_parallel_size,
                 'enable_graph_mode': int(os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')),
-- 
Gitee

From 23109958a25b462fa90a4c63a2f60e93593fe4f7 Mon Sep 17 00:00:00 2001
From: A1waysBeenHere
Date: Thu, 19 Jun 2025 16:53:49 +0800
Subject: [PATCH 4/5] Bug fix

---
 mindspeed_rl/workers/resharding/weight_adaptor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py
index 5d06eeb7..300eebeb 100644
--- a/mindspeed_rl/workers/resharding/weight_adaptor.py
+++ b/mindspeed_rl/workers/resharding/weight_adaptor.py
@@ -310,7 +310,7 @@ class Qwen3MvWeightAdaptor(MegatronVLLMWeightAdaptor):
             ("mlp.linear_fc2", "mlp.down_proj"),
             ("decoder.final_layernorm", "model.norm"),
("output_layer", "lm_head"), - ("self_attention.q_layernorm", "self_attn.q_norm") + ("self_attention.q_layernorm", "self_attn.q_norm"), ("self_attention.k_layernorm", "self_attn.k_norm") ] -- Gitee From c74411c4f9cc6b457ea92036a22bfef0d06de704 Mon Sep 17 00:00:00 2001 From: A1waysBeenHere Date: Mon, 23 Jun 2025 16:28:05 +0800 Subject: [PATCH 5/5] confict fix --- mindspeed_rl/models/rollout/vllm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 90535c30..3aac416f 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -179,7 +179,7 @@ class VLLMInferEngine(BaseInferEngine): gpu_memory_utilization=gpu_memory_utilization, max_num_seqs=max_num_seqs, max_model_len=max_model_len, - seed=int(os.getenv("RANK", "0")) // infer_tensor_parallel_size, + seed=self.sampling_params.seed, additional_config={ 'expert_tensor_parallel_size': infer_expert_tensor_parallel_size, 'enable_graph_mode': int(os.environ.get('VLLM_ENABLE_GRAPH_MODE', '0')), -- Gitee