From ba147bf09b3a0720226efbd21778052e684bca17 Mon Sep 17 00:00:00 2001 From: ningbenzhe1 Date: Tue, 5 Aug 2025 12:02:14 +0800 Subject: [PATCH 1/6] support mtp --- cli/train_dapo.py | 19 +++- cli/train_grpo.py | 19 +++- cli/train_ppo.py | 19 +++- mindspeed_rl/config_cls/generate_config.py | 2 + mindspeed_rl/config_cls/megatron_config.py | 3 + mindspeed_rl/models/actor_rollout_hybrid.py | 2 +- mindspeed_rl/models/base/__init__.py | 106 ++++++++++++++++++ .../models/base/base_inference_engine.py | 4 +- .../models/base/base_training_engine.py | 19 +++- .../vllm_adapter/megatron_weight_loaders.py | 4 +- mindspeed_rl/models/rollout/vllm_engine.py | 54 ++++++--- mindspeed_rl/utils/utils.py | 27 +++++ mindspeed_rl/workers/actor_hybrid_worker.py | 16 ++- .../resharding/vllm_weight_container.py | 2 + .../workers/resharding/weight_adaptor.py | 52 ++++++++- 15 files changed, 301 insertions(+), 47 deletions(-) diff --git a/cli/train_dapo.py b/cli/train_dapo.py index 2d78c0a0..25300195 100644 --- a/cli/train_dapo.py +++ b/cli/train_dapo.py @@ -17,7 +17,7 @@ from ray.util import placement_group from mindspeed_rl.utils import seed_all from mindspeed_rl.utils import get_tokenizer -from mindspeed_rl.utils.utils import MsProbe, get_node_nums +from mindspeed_rl.utils.utils import MsProbe, get_node_nums, print_config from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.utils.utils import parse_args_from_config, init_torch_compile from mindspeed_rl.config_cls.validate_config import validate_rl_args @@ -223,6 +223,9 @@ def parse_training_config(config: Dict): role="integrated" ) + print_config(rl_config) + print_config(generate_config) + return { "actor_config": actor_config, "ref_config": ref_config, @@ -293,13 +296,18 @@ def gpt_model_provider(pre_process, post_process): Returns: - Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + Union[GPTModel]: The returned model """ from megatron.training import get_args from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.transformer.spec_utils import import_module from megatron.training.arguments import core_transformer_config_from_args + from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + from mindspeed_rl.models.base import gpt_model_forward + + GPTModel.forward = gpt_model_forward + args = get_args() logger.info('building GPT model ...') @@ -310,7 +318,9 @@ def gpt_model_provider(pre_process, post_process): transformer_layer_spec = import_module(args.spec) else: transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, qk_layernorm=args.qk_layernorm) - + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=False) model = GPTModel( config=config, transformer_layer_spec=transformer_layer_spec, @@ -323,7 +333,8 @@ def gpt_model_provider(pre_process, post_process): share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, position_embedding_type=args.position_embedding_type, rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, ) return model diff --git a/cli/train_grpo.py b/cli/train_grpo.py index f2633eff..47bc2334 100644 --- a/cli/train_grpo.py +++ b/cli/train_grpo.py @@ -19,7 +19,7 @@ from 
mindspeed_rl.config_cls.validate_config import validate_rl_args from mindspeed_rl.utils import get_tokenizer from mindspeed_rl.datasets.build_dataset import build_train_valid_test_datasets from mindspeed_rl.utils import seed_all -from mindspeed_rl.utils.utils import MsProbe, get_node_nums +from mindspeed_rl.utils.utils import MsProbe, get_node_nums, print_config from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.utils.utils import parse_args_from_config, init_torch_compile from mindspeed_rl.config_cls.megatron_config import MegatronConfig @@ -234,6 +234,9 @@ def parse_training_config(config: Dict): role="integrated" ) + print_config(rl_config) + print_config(generate_config) + return { "actor_config": actor_config, "ref_config": ref_config, @@ -304,13 +307,18 @@ def gpt_model_provider(pre_process, post_process): Returns: - Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + Union[GPTModel]: The returned model """ from megatron.training import get_args from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.transformer.spec_utils import import_module from megatron.training.arguments import core_transformer_config_from_args + from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + from mindspeed_rl.models.base import gpt_model_forward + + GPTModel.forward = gpt_model_forward + args = get_args() logger.info('building GPT model ...') @@ -321,7 +329,9 @@ def gpt_model_provider(pre_process, post_process): transformer_layer_spec = import_module(args.spec) else: transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, qk_layernorm=args.qk_layernorm) - + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=False) model = GPTModel( config=config, transformer_layer_spec=transformer_layer_spec, @@ -334,7 +344,8 @@ def gpt_model_provider(pre_process, post_process): share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, position_embedding_type=args.position_embedding_type, rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, ) return model diff --git a/cli/train_ppo.py b/cli/train_ppo.py index 9682ac34..369f4277 100644 --- a/cli/train_ppo.py +++ b/cli/train_ppo.py @@ -19,7 +19,7 @@ from mindspeed_rl.config_cls.validate_config import validate_rl_args from mindspeed_rl.utils import get_tokenizer from mindspeed_rl.datasets.build_dataset import build_train_valid_test_datasets from mindspeed_rl.utils import seed_all -from mindspeed_rl.utils.utils import MsProbe, get_node_nums +from mindspeed_rl.utils.utils import MsProbe, get_node_nums, print_config from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.utils.utils import parse_args_from_config, init_torch_compile from mindspeed_rl.config_cls.megatron_config import MegatronConfig @@ -264,6 +264,9 @@ def parse_training_config(config: Dict): role="integrated" ) + print_config(rl_config) + print_config(generate_config) + return { "actor_config": actor_config, "ref_config": ref_config, @@ -335,13 +338,18 @@ def gpt_model_provider(pre_process, post_process): Returns: - Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + Union[GPTModel]: The returned model """ from 
megatron.training import get_args from megatron.core.models.gpt import GPTModel from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.transformer.spec_utils import import_module from megatron.training.arguments import core_transformer_config_from_args + from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec + from mindspeed_rl.models.base import gpt_model_forward + + GPTModel.forward = gpt_model_forward + args = get_args() logger.info('building GPT model ...') @@ -352,7 +360,9 @@ def gpt_model_provider(pre_process, post_process): transformer_layer_spec = import_module(args.spec) else: transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm) - + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=False) model = GPTModel( config=config, transformer_layer_spec=transformer_layer_spec, @@ -365,7 +375,8 @@ def gpt_model_provider(pre_process, post_process): share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, position_embedding_type=args.position_embedding_type, rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, ) return model diff --git a/mindspeed_rl/config_cls/generate_config.py b/mindspeed_rl/config_cls/generate_config.py index e3f46dbc..ffd10167 100644 --- a/mindspeed_rl/config_cls/generate_config.py +++ b/mindspeed_rl/config_cls/generate_config.py @@ -30,6 +30,7 @@ class GenerateConfig(BaseConfig): If False, we will use ACL graph and eager execution in hybrid for maximal performance and flexibility. torchair_graph: Whether to enable TorchAir graph optimization. If True, uses accelerated computational graph optimizations. enable_expert_parallel: Whether to enable expert parallel computation for Mixture-of-Experts (MoE) layers. + enable_sequence_parallelism: Whether to enable sequence parallelism. sampling_config: Configuration for text generation sampling. Default values are set for various sampling parameters. - num_completions: The number of independent completions to generate for each input prompt. Default is 1. - logprobs: The number of top tokens to return log probabilities for. Default is 1. 
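
Note on the new switch documented above: the sketch below mirrors how `enable_sequence_parallelism` is wired into the vLLM engine later in this series (the `LLM` keyword arguments follow the call added in `vllm_engine.py` below; the standalone helper and its name are illustrative assumptions, not part of this patch).

    from vllm import LLM

    def build_rollout_engine(model_path: str, enable_sequence_parallelism: bool = False) -> LLM:
        # Hypothetical helper: forwards the GenerateConfig flag into vLLM's compilation
        # pass config, the same way VLLMInferEngine does in this patch series.
        return LLM(
            model=model_path,
            compilation_config={
                "pass_config": {
                    "enable_sequence_parallelism": enable_sequence_parallelism,
                }
            },
        )
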
@@ -85,6 +86,7 @@ class GenerateConfig(BaseConfig): self.enforce_eager = True self.torchair_graph = False self.enable_expert_parallel = False + self.enable_sequence_parallelism = False # 采样配置的默认值,用于生成文本时的采样策略设置 self.sampling_config = { diff --git a/mindspeed_rl/config_cls/megatron_config.py b/mindspeed_rl/config_cls/megatron_config.py index 50e6670f..f0208bfe 100644 --- a/mindspeed_rl/config_cls/megatron_config.py +++ b/mindspeed_rl/config_cls/megatron_config.py @@ -394,6 +394,9 @@ class MegatronConfig(BaseConfig): self.moe_aux_loss_coeff = 0.001 self.gemm_gradient_accumulation_fusion = False + self.mtp_num_layers = None + self.mtp_loss_scaling_factor = 0.0 + self.update(training_config, model_config) self.pad_to_multiple_of = self.tensor_model_parallel_size * self.context_parallel_size diff --git a/mindspeed_rl/models/actor_rollout_hybrid.py b/mindspeed_rl/models/actor_rollout_hybrid.py index a1874dd4..49d2977e 100644 --- a/mindspeed_rl/models/actor_rollout_hybrid.py +++ b/mindspeed_rl/models/actor_rollout_hybrid.py @@ -95,7 +95,7 @@ class ActorRolloutHybrid(ABC): **kwargs, ) else: - res = self.inference_actor.generate_sequences(prompts_list, **kwargs)[0] + res = self.inference_actor.generate_sequences(prompts_list, **kwargs) return res @mstx_timer_decorator diff --git a/mindspeed_rl/models/base/__init__.py b/mindspeed_rl/models/base/__init__.py index 9e76b6a2..d5faf9f0 100644 --- a/mindspeed_rl/models/base/__init__.py +++ b/mindspeed_rl/models/base/__init__.py @@ -1,2 +1,108 @@ # coding=utf-8 # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. +from typing import Optional + +import torch +from torch import Tensor + + +def gpt_model_forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params=None, + packed_seq_params=None, + extra_block_kwargs: dict = None, + loss_mask: Optional[Tensor] = None, +) -> Tensor: + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + from megatron.training import get_args + args = get_args() + + if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): + args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run decoder. 
+ hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + if self.mtp_process: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + embedding=self.embedding, + output_layer=self.output_layer, + output_weight=output_weight, + compute_language_model_loss=self.compute_language_model_loss, + **(extra_block_kwargs or {}), + ) + + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + + if not self.post_process: + return hidden_states + + if args.dim_model_base is not None: + hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + if getattr(args, "task", False) and args.task[0] == 'needlebench': + hidden_states = hidden_states[-100:] + logits, _ = self.output_layer(hidden_states, weight=output_weight) + + # new add to scale logits + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None or args.mtp_num_layers >= 1: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + loss = self.compute_language_model_loss(labels, logits) + return loss diff --git a/mindspeed_rl/models/base/base_inference_engine.py b/mindspeed_rl/models/base/base_inference_engine.py index 966c8c67..24cdd836 100644 --- a/mindspeed_rl/models/base/base_inference_engine.py +++ b/mindspeed_rl/models/base/base_inference_engine.py @@ -9,6 +9,7 @@ class BaseInferEngine(ABC): including tokenizer information, parallel sizes during training and inference, model length limits, data types, GPU memory utilization, and trust settings for remote code. """ + def __init__( self, tokenizer_name_or_path: str, @@ -27,6 +28,7 @@ class BaseInferEngine(ABC): gpu_memory_utilization: float = 0.5, # Default value set to 0.5 trust_remote_code: bool = True, enable_expert_parallel: bool = False, + megatron_config=None, ): """ Initialize the base inference engine. 
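
For readers unfamiliar with the MTP block that `self.mtp(...)` in the forward above delegates to, the toy module below sketches what one multi-token-prediction depth does in the DeepSeek-V3 style: the embedding of the shifted input tokens and the main model's hidden states are each normalized, fused by a projection, and passed through one extra transformer layer before a shared output head. The `enorm`/`hnorm`/`eh_proj` names match the weight mappings added later in this series; everything else (layer sizes, `nn.TransformerEncoderLayer`, batch-first layout) is a simplification, not the MindSpeed implementation. In the patched forward, the labels, loss mask and `compute_language_model_loss` are passed into `self.mtp(...)` so the auxiliary MTP losses can be computed inside the block.

    import torch
    from torch import nn

    class ToyMTPDepth(nn.Module):
        """Minimal illustration of one MTP depth; not the production module."""

        def __init__(self, hidden_size: int, vocab_size: int):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, hidden_size)
            self.enorm = nn.LayerNorm(hidden_size)                   # normalizes the token embeddings
            self.hnorm = nn.LayerNorm(hidden_size)                   # normalizes the incoming hidden states
            self.eh_proj = nn.Linear(2 * hidden_size, hidden_size)   # fuses both streams
            self.block = nn.TransformerEncoderLayer(hidden_size, nhead=4, batch_first=True)
            self.head = nn.Linear(hidden_size, vocab_size)           # stands in for the shared output head

        def forward(self, hidden_states: torch.Tensor, shifted_input_ids: torch.Tensor):
            # hidden_states: [batch, seq, hidden]; shifted_input_ids: [batch, seq]
            fused = self.eh_proj(torch.cat(
                [self.enorm(self.embed(shifted_input_ids)), self.hnorm(hidden_states)], dim=-1))
            hidden_states = self.block(fused)
            logits = self.head(hidden_states)                        # predictions one extra step ahead
            return logits, hidden_states
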
@@ -63,7 +65,7 @@ class BaseInferEngine(ABC): self.gpu_memory_utilization = gpu_memory_utilization self.trust_remote_code = trust_remote_code self.enable_expert_parallel = enable_expert_parallel - + self.megatron_config = megatron_config @abstractmethod def init_cache_engine(self): diff --git a/mindspeed_rl/models/base/base_training_engine.py b/mindspeed_rl/models/base/base_training_engine.py index a964fac3..064ce666 100644 --- a/mindspeed_rl/models/base/base_training_engine.py +++ b/mindspeed_rl/models/base/base_training_engine.py @@ -176,7 +176,7 @@ class BaseTrainingEngine(ABC): elif self.use_remove_padding: input_ids, position_ids, process_batch, seqlens_in_batch, cu_seqlens_padded, index = self._get_forward_batch_info(batch_iter) - output = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids) + output = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids, labels=input_ids if self.megatron_config.mtp_num_layers else None) output.div_(self.temperature) return output, partial(self.loss_func.compute_loss, batch=process_batch, @@ -190,7 +190,7 @@ class BaseTrainingEngine(ABC): index=index) else: input_ids, position_ids, process_batch, index = self._get_forward_batch_info(batch_iter) - output = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids) + output = model(input_ids=input_ids, attention_mask=None, position_ids=position_ids, labels=input_ids if self.megatron_config.mtp_num_layers else None) output.div_(self.temperature) return output, partial(self.loss_func.compute_loss, @@ -273,7 +273,20 @@ class BaseTrainingEngine(ABC): cu_seqlens_padded_ring = cu_seqlens_padded if self.megatron_config.cp_attention_mask_type == 'causal': cu_seqlens_padded_ring = (cu_seqlens_padded / get_ring_degree(self.megatron_config)).to(torch.int) - self.set_actual_seq_len(cu_seqlens_padded_ring.tolist()) + cu_seqlens_padded_list = cu_seqlens_padded_ring.tolist() + if self.megatron_config.mtp_num_layers: + seq_len = input_ids.shape[1] + mtp_res = [cu_seqlens_padded_list] + for i in range(1, self.megatron_config.mtp_num_layers + 1): + next_actual_seq_len = [] + for j in cu_seqlens_padded_list: + if j % seq_len == 0: + next_actual_seq_len.append(j) + else: + next_actual_seq_len.append(j - i) + mtp_res.append(next_actual_seq_len) + cu_seqlens_padded_list = mtp_res + self.set_actual_seq_len(cu_seqlens_padded_list) if cp_size > 1: input_ids, position_ids, batch, index = self._get_batch_data_with_cp(batch, input_ids, position_ids, labels) diff --git a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py index 5c10c992..51c58b1e 100644 --- a/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py +++ b/mindspeed_rl/models/rollout/vllm_adapter/megatron_weight_loaders.py @@ -185,6 +185,7 @@ def load_single_weight(params_dict, name, loaded_weight): def update_megatron_weight_loader(): + from torch.nn import Linear from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ReplicatedLinear) @@ -199,7 +200,8 @@ def update_megatron_weight_loader(): VocabParallelEmbedding: parallel_weight_loader, ParallelLMHead: parallel_weight_loader, ReplicatedLinear: parallel_weight_loader, - FusedMoE: parallel_weight_loader + FusedMoE: parallel_weight_loader, + Linear: parallel_weight_loader } for layer_class, weight_loader in LAYER_WEIGHT_MEGATRON_LOADER_REGISTRY.items(): diff --git 
a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 27e4cc71..38bac167 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -42,6 +42,7 @@ class VLLMInferEngine(BaseInferEngine): infer_pipeline_parallel_size: int, infer_expert_parallel_size: int, sampling_config: dict, + megatron_config: dict = None, prompt_type: str = None, prompt_type_path: str = None, enable_prefix_caching: bool = False, @@ -58,6 +59,7 @@ class VLLMInferEngine(BaseInferEngine): limit_mm_image_per_prompt: int = 1, limit_mm_video_per_prompt: int = 0, enable_expert_parallel: bool = False, + enable_sequence_parallelism: bool = False, **kwargs ): """ @@ -100,6 +102,7 @@ class VLLMInferEngine(BaseInferEngine): gpu_memory_utilization=gpu_memory_utilization, trust_remote_code=trust_remote_code, enable_expert_parallel=enable_expert_parallel, + megatron_config=megatron_config, ) # Additional initialization logic for VLLMInferEngine @@ -130,6 +133,10 @@ class VLLMInferEngine(BaseInferEngine): trust_remote_code=trust_remote_code ) + if self.megatron_config.mtp_num_layers and self.megatron_config.mtp_num_layers != self.hf_config.num_nextn_predict_layers: + raise ValueError(f"mtp_num_layers must equal to num_nextn_predict_layers, but got mtp_num_layers=" + f"{self.megatron_config.mtp_num_layers} and num_nextn_predict_layers={self.hf_config.num_nextn_predict_layers}") + self.tokenizer = get_tokenizer(tokenizer_name_or_path, prompt_type=prompt_type, prompt_type_path=prompt_type_path) self.pad_token_id = ( @@ -186,6 +193,15 @@ class VLLMInferEngine(BaseInferEngine): max_num_batched_tokens=max_num_batched_tokens, enable_expert_parallel=enable_expert_parallel, limit_mm_per_prompt=limit_mm_per_prompt_dict, + speculative_config={ + "method": "deepseek_mtp", + "num_speculative_tokens": self.megatron_config.mtp_num_layers, + } if self.megatron_config.mtp_num_layers else None, + compilation_config={ + "pass_config": { + "enable_sequence_parallelism": enable_sequence_parallelism + } + }, additional_config={ "torchair_graph_config": { "enabled": torchair_graph, @@ -200,6 +216,11 @@ class VLLMInferEngine(BaseInferEngine): self.engine = self.llm.llm_engine self.model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model() + self.mtp_model = None + if self.megatron_config.mtp_num_layers: + self.mtp_model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.drafter.model + self.model.model.layers.append(self.mtp_model.model.layers['2']) + self.cpu_model = {} for name, params in self.model.named_parameters(): @@ -268,7 +289,19 @@ class VLLMInferEngine(BaseInferEngine): def offload_model_weights(self): for name, params in self.model.named_parameters(): params.data = self.cpu_model[name] - if hasattr(self.model, 'model') and hasattr(self.model.model.layers[-1].self_attn, "mla_attn"): + + if self.mtp_model: + last_layer = -2 + mla = self.model.model.layers[-1].mtp_block.self_attn.mla_attn.impl + if hasattr(mla, "w_kc"): + mla.w_kc = None + mla.w_vc = None + if hasattr(mla, "W_UV"): + mla.W_UV = None + mla.W_UK_T = None + else: + last_layer = -1 + if hasattr(self.model, 'model') and hasattr(self.model.model.layers[last_layer].self_attn, "mla_attn"): for i in range(self.model.model.start_layer, self.model.model.end_layer): mla = self.model.model.layers[i].self_attn.mla_attn.impl if hasattr(mla, "w_kc"): @@ -355,33 +388,18 @@ class VLLMInferEngine(BaseInferEngine): prompt_ids = 
[torch.tensor(idx_list[indexes.index(index)])] index = [index] response_ids = self._post_process_outputs([output]) - yield (prompt_ids, *response_ids), index + yield (prompt_ids, response_ids), index if STOP_SIGNAL: self.engine.abort_request([request_id]) def _post_process_outputs(self, request_outputs): output_token_ids = [] - logprobs = [] for request_output in request_outputs: # List[RequestOutput] outputs = request_output.outputs for output in outputs: # List[CompletionOutput], usually len == 1 output_token_ids.append(torch.tensor(output.token_ids)) - logprobs_dicts = output.logprobs - if logprobs_dicts is None: - continue - - logprob = [] - for logprobs_dict, token_id in zip(logprobs_dicts, output.token_ids): - logprob.append(logprobs_dict[token_id].logprob) - logprobs.append(torch.tensor(logprob)) - - output_token_ids = pad_sequence(output_token_ids, batch_first=True, - padding_value=self.pad_token_id) - if len(logprobs) > 0: - logprobs = pad_sequence(logprobs, batch_first=True, - padding_value=self.pad_token_id) - return output_token_ids, logprobs + return output_token_ids @contextmanager def update_sampling_params(self, **kwargs): diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py index 9dff5afd..0dfeee66 100644 --- a/mindspeed_rl/utils/utils.py +++ b/mindspeed_rl/utils/utils.py @@ -21,6 +21,8 @@ import torch_npu import torch.distributed as dist from torch import Tensor +from mindspeed_rl.utils import Loggers + def get_node_nums(): nodes = ray.nodes() @@ -655,5 +657,30 @@ def _get_ip_by_ifname(): return None +def print_config(config): + logger = Loggers("print config") + + # 打印 rl_config 配置 + logger.info(f"\n\033[1m{config.__module__} CONFIGURATION\033[0m") + logger.info("=" * 50) + config = config.dict() + max_key_len = max(len(k) for k in config.keys()) + for key, value in config.items(): + if key in ['verifier_function', 'verifier_weight']: + formatted_value = ', '.join(map(str, value)) + logger.info(f"{key:<{max_key_len}} : [{formatted_value}]") + elif key == 'sampling_config': + continue + else: + logger.info(f"{key:<{max_key_len}} : {value}") + + # 子配置 sampling_config + sampling_config = config.get('sampling_config', {}) + if sampling_config: + max_sub_key_len = max(len(k) for k in sampling_config.keys()) + for key, value in sampling_config.items(): + logger.info(f" {key:<{max_sub_key_len}} : {value}") + + def is_multimodal(): return eval(os.getenv("IS_MULTIMODAL", "False")) diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index c22d68d7..3e91f1ac 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -339,9 +339,9 @@ class ActorHybridWorkerBase(BaseWorker): if batch_data and index: if self.rl_config.async_engine: logger.info(f"do async generate process.") - self.async_generate_process(batch_data, index, pad_token_id) + self.async_generate_process(batch_data, index) else: - self.sync_generate_process(batch_data, experience_count, index, pad_token_id) + self.sync_generate_process(batch_data, experience_count, index) if self.enable_partial_rollout: torch.distributed.barrier() end_time = time.time() @@ -372,7 +372,7 @@ class ActorHybridWorkerBase(BaseWorker): ) logger.info("finish generate_sequences") - def sync_generate_process(self, batch_data, experience_count, index, pad_token_id): + def sync_generate_process(self, batch_data, experience_count, index): if not self.enable_partial_rollout: indexes = list(range(0, experience_count, 
self.rl_config.n_samples_per_prompt)) prompts_data = batch_data['prompts'][indexes] @@ -393,15 +393,13 @@ class ActorHybridWorkerBase(BaseWorker): prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm] if self.enable_partial_rollout: max_tokens = self.generate_config.sampling_config["max_tokens"] // self.rl_config.partial_rollout_max_split - responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), + responses = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), max_tokens=max_tokens, n=1, extra_info=batch_data) else: - responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), + responses = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), extra_info=batch_data) - responses = remove_padding_and_split_to_list(responses_pad_right, self.tokenizer.eod, pad_token_id) - if is_multimodal(): prompts_data = batch_data['input_ids'][indexes].cpu().unbind() else: @@ -449,7 +447,7 @@ class ActorHybridWorkerBase(BaseWorker): MsProbe.save_data({"responses": responses, "prompts": prompts}) - def async_generate_process(self, batch_data, index, pad_token_id): + def async_generate_process(self, batch_data, index): self.actor_hybrid.inference_actor.init_cache_engine() prompts_data = batch_data['prompts'] prompt_length_data = batch_data['prompt_length'] @@ -480,7 +478,6 @@ class ActorHybridWorkerBase(BaseWorker): for samples, idx_output in response_generator: prompts, responses, log_probs = samples - responses = remove_padding_and_split_to_list(responses, self.tokenizer.eod, pad_token_id) remove_input_ids = False if self.enable_partial_rollout and len(responses[0]) == 1: @@ -636,6 +633,7 @@ class ActorHybridWorkerBase(BaseWorker): max_num_batched_tokens=self.generate_config.max_num_batched_tokens, limit_mm_image_per_prompt=self.generate_config.limit_mm_image_per_prompt, limit_mm_video_per_prompt=self.generate_config.limit_mm_video_per_prompt, + enable_sequence_parallelism=self.generate_config.enable_sequence_parallelism, ) return rollout diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py index 729a602e..f8ce2ea3 100644 --- a/mindspeed_rl/workers/resharding/vllm_weight_container.py +++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py @@ -85,6 +85,8 @@ class MegatronStyleVllmWeightContainer: ## vpp self._num_layer_list = self._build_num_layer_list(num_layer_list) + if self.model_config.num_nextn_predict_layers >= 1: + self._num_layer_list[-1] = self._num_layer_list[-1] + 1 self._vpp_rank = self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK if self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK else 0 self._vpp_size = self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE if self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE else 1 self._vpp_layer_list = self._build_vpp_layer_list(self._num_layer_list) diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py index d8aa7726..a6fdeee7 100644 --- a/mindspeed_rl/workers/resharding/weight_adaptor.py +++ b/mindspeed_rl/workers/resharding/weight_adaptor.py @@ -243,14 +243,13 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): 'delete': ['q_a_proj']} self.params_mapping = [ # (megatron core gpt model name, vllm model name) - ("embedding.word_embeddings", "model.embed_tokens"), ("self_attention.linear_qkv", "self_attn.qkv_proj"), # q_a_proj, kv_a_proj_with_mqa 
("self_attention.linear_proj", "self_attn.o_proj"), ("input_layernorm", "input_layernorm"), ("pre_mlp_layernorm", "post_attention_layernorm"), ("mlp.linear_fc1", "mlp.gate_up_proj"), ("mlp.linear_fc2", "mlp.down_proj"), - ("decoder.final_layernorm", "model.norm"), + (f"{'decoder.' if int(self.model_config.num_nextn_predict_layers) == 0 else ''}final_layernorm", "model.norm"), ("output_layer", "lm_head"), ("self_attention.linear_qb", "self_attn.q_b_proj"), ("self_attention.linear_kvb", "self_attn.kv_b_proj"), @@ -262,6 +261,12 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): ("mlp.experts.weight2", "mlp.experts.w2_weight"), ("self_attention.q_layernorm", "self_attn.q_a_layernorm"), ("self_attention.k_layernorm", "self_attn.kv_a_layernorm"), + ("mtp.final_layernorms.0.weight", "shared_head.norm.weight"), + ("output_layer.weight", "shared_head.head.weight"), + ("embedding.word_embeddings.weight", "embed_tokens.weight"), + ("enorm", "enorm"), + ("hnorm", "hnorm"), + ("eh_proj", "eh_proj"), ] def get_weight_buffer_meta(self, model, valid_names=None): @@ -312,6 +317,49 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): return weight_names_meta + def replace_name_i2t(self, inference_name): + """ + transfer inference weight name to training weight name + """ + for m_name, v_name in self.params_mapping: + if v_name not in inference_name: + continue + if "embed_tokens" in inference_name or "shared_head.norm" in inference_name or "shared_head.head" in inference_name: + return m_name + + if int(self.model_config.num_nextn_predict_layers) == 1 and f".{self.model_config.num_hidden_layers}." in inference_name: + inference_name = inference_name.replace("model", "mtp") + inference_name = inference_name.replace("mtp_block", "transformer_layer") + inference_name = inference_name.replace(str(self.model_config.num_hidden_layers), "0") + if "enorm" in inference_name or "hnorm" in inference_name or "eh_proj" in inference_name: + param_name = inference_name + else: + vllm_name_list = inference_name.split(".") + param_name_list = vllm_name_list[:4] + weight_or_bias = vllm_name_list[-1] + param_name_list.append(m_name) + if weight_or_bias in ['weight', 'bias']: + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + elif "layers" in inference_name: # deal with decoder layers + inference_name = inference_name.replace("model", "decoder") + vllm_name_list = inference_name.split(".") + if "layer_norm_weight" in vllm_name_list or "layer_norm_bias" in vllm_name_list: + param_name_list = vllm_name_list[:3] + param_name_list.append(m_name) + param_name = ".".join(param_name_list) + else: + param_name_list = vllm_name_list[:3] + weight_or_bias = vllm_name_list[-1] + param_name_list.append(m_name) + if weight_or_bias in ['weight', 'bias']: + param_name_list.append(weight_or_bias) + param_name = ".".join(param_name_list) + else: + param_name = inference_name.replace(v_name, m_name) + return param_name + + class QwenMVWeightAdaptor(MegatronVLLMWeightAdaptor): """ -- Gitee From 25386621f95b3c0464319b374dad93fbae0c5cba Mon Sep 17 00:00:00 2001 From: ningbenzhe1 Date: Thu, 7 Aug 2025 01:29:02 +0000 Subject: [PATCH 2/6] 1 Signed-off-by: ningbenzhe1 --- mindspeed_rl/models/rollout/vllm_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index 38bac167..ba4674ec 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ 
-219,7 +219,7 @@ class VLLMInferEngine(BaseInferEngine): self.mtp_model = None if self.megatron_config.mtp_num_layers: self.mtp_model = self.llm.llm_engine.model_executor.driver_worker.worker.model_runner.drafter.model - self.model.model.layers.append(self.mtp_model.model.layers['2']) + self.model.model.layers.append(self.mtp_model.model.layers[f'{self.megatron_config.num_layers}']) self.cpu_model = {} -- Gitee From 222223e0215a8fcf016613a431a7642e64b04310 Mon Sep 17 00:00:00 2001 From: ningbenzhe1 Date: Thu, 7 Aug 2025 01:31:11 +0000 Subject: [PATCH 3/6] fix mtp key Signed-off-by: ningbenzhe1 --- mindspeed_rl/workers/resharding/weight_adaptor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py index a6fdeee7..fa33b435 100644 --- a/mindspeed_rl/workers/resharding/weight_adaptor.py +++ b/mindspeed_rl/workers/resharding/weight_adaptor.py @@ -327,13 +327,16 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): if "embed_tokens" in inference_name or "shared_head.norm" in inference_name or "shared_head.head" in inference_name: return m_name - if int(self.model_config.num_nextn_predict_layers) == 1 and f".{self.model_config.num_hidden_layers}." in inference_name: + has_mtp_special_keys = "enorm" in inference_name or "hnorm" in inference_name or "eh_proj" in inference_name + has_mtp_block = "mtp_block" in inference_name + + if int(self.model_config.num_nextn_predict_layers) == 1 and (has_mtp_special_keys or has_mtp_block): inference_name = inference_name.replace("model", "mtp") - inference_name = inference_name.replace("mtp_block", "transformer_layer") inference_name = inference_name.replace(str(self.model_config.num_hidden_layers), "0") - if "enorm" in inference_name or "hnorm" in inference_name or "eh_proj" in inference_name: + if has_mtp_special_keys: param_name = inference_name else: + inference_name = inference_name.replace("mtp_block", "transformer_layer") vllm_name_list = inference_name.split(".") param_name_list = vllm_name_list[:4] weight_or_bias = vllm_name_list[-1] -- Gitee From 1b42f2a73a2b0360b9e1f35b1ccdef75e38690aa Mon Sep 17 00:00:00 2001 From: ningbenzhe1 Date: Thu, 7 Aug 2025 03:37:06 +0000 Subject: [PATCH 4/6] FIX Signed-off-by: ningbenzhe1 --- mindspeed_rl/workers/resharding/weight_adaptor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspeed_rl/workers/resharding/weight_adaptor.py b/mindspeed_rl/workers/resharding/weight_adaptor.py index fa33b435..a0320a8e 100644 --- a/mindspeed_rl/workers/resharding/weight_adaptor.py +++ b/mindspeed_rl/workers/resharding/weight_adaptor.py @@ -332,11 +332,12 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): if int(self.model_config.num_nextn_predict_layers) == 1 and (has_mtp_special_keys or has_mtp_block): inference_name = inference_name.replace("model", "mtp") - inference_name = inference_name.replace(str(self.model_config.num_hidden_layers), "0") + inference_name = inference_name.replace("mtp_block", "transformer_layer") + vllm_name_list = inference_name.split(".") + inference_name = inference_name.replace(str(vllm_name_list[2]), "0") if has_mtp_special_keys: param_name = inference_name else: - inference_name = inference_name.replace("mtp_block", "transformer_layer") vllm_name_list = inference_name.split(".") param_name_list = vllm_name_list[:4] weight_or_bias = vllm_name_list[-1] @@ -363,7 +364,6 @@ class DeepSeekMVWeightAdaptor(MegatronVLLMWeightAdaptor): return 
param_name - class QwenMVWeightAdaptor(MegatronVLLMWeightAdaptor): """ Megatron-vLLM WeightAdaptor for Qwen model architectures. -- Gitee From ec83963a3f7b659f87be2093ae0778c7b2f8d8e9 Mon Sep 17 00:00:00 2001 From: ningbenzhe1 Date: Thu, 7 Aug 2025 08:16:39 +0000 Subject: [PATCH 5/6] 1 Signed-off-by: ningbenzhe1 --- .../resharding/vllm_weight_container.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py index f8ce2ea3..e2705c71 100644 --- a/mindspeed_rl/workers/resharding/vllm_weight_container.py +++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py @@ -639,8 +639,11 @@ class MegatronStyleVllmWeightContainer: else: partition_dim = get_tensor_parallel_partition_dim(param) infer_params = torch.cat(infer_params, dim=partition_dim) - split_params = torch.chunk(infer_params, self._infer_tp_size, dim=partition_dim) - new_params_list = list(split_params) + if not "eh_proj" in name: + split_params = torch.chunk(infer_params, self._infer_tp_size, dim=partition_dim) + new_params_list = list(split_params) + else: + new_params_list = [infer_params] * self._infer_tp_size # make_list param_list = new_params_list @@ -653,7 +656,7 @@ class MegatronStyleVllmWeightContainer: return param_list[infer_tp_rank_in_group] def allgather_tp_param(self, param, name): - if self._tp_size <= self._infer_tp_size: + if self._tp_size < self._infer_tp_size: return param tp_allgather_size = get_tp_allgather_world_size() tp_allgather_group = get_tp_allgather_group() @@ -668,9 +671,12 @@ class MegatronStyleVllmWeightContainer: torch.distributed.all_gather(infer_param, param, group=tp_allgather_group) if self.enable_validate: update_md5_by_rank(infer_param, param, self.origin_params_for_md5, self.infer_params_for_md5) - part_len = len(infer_param) // self._infer_tp_size - start = self._rank % self._infer_tp_size - part_param = infer_param[part_len * start:part_len * (start + 1)] - infer_param = self._default_tp_concat_fn(name, param, part_param) + if "eh_proj" in name: + infer_param = self._default_tp_concat_fn(name, param, infer_param) + else: + part_len = len(infer_param) // self._infer_tp_size + start = self._rank % self._infer_tp_size + part_param = infer_param[part_len * start:part_len * (start + 1)] + infer_param = self._default_tp_concat_fn(name, param, part_param) return infer_param -- Gitee From d4e7479d9799bb637256e54c9e53ebd6ac74bde7 Mon Sep 17 00:00:00 2001 From: ningbenzhe1 Date: Thu, 7 Aug 2025 08:17:24 +0000 Subject: [PATCH 6/6] 1 Signed-off-by: ningbenzhe1 --- mindspeed_rl/models/rollout/vllm_engine.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mindspeed_rl/models/rollout/vllm_engine.py b/mindspeed_rl/models/rollout/vllm_engine.py index ba4674ec..a7911c5f 100644 --- a/mindspeed_rl/models/rollout/vllm_engine.py +++ b/mindspeed_rl/models/rollout/vllm_engine.py @@ -331,6 +331,15 @@ class VLLMInferEngine(BaseInferEngine): mla.W_UV = None mla.W_UK_T = None mla.process_weights_after_loading(None) + if self.mtp_model: + mla = self.model.model.layers[-1].mtp_block.self_attn.mla_attn.impl + if hasattr(mla, "w_kc"): + mla.w_kc = None + mla.w_vc = None + if hasattr(mla, "W_UV"): + mla.W_UV = None + mla.W_UK_T = None + mla.process_weights_after_loading(None) @torch.no_grad() def generate_sequences(self, idx_list, **kwargs): -- Gitee
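
As a summary of what the final `replace_name_i2t` in this series is intended to produce for a DeepSeek-V3-style checkpoint, the table below lists a few inference-to-training name translations, assuming `num_hidden_layers == 61` and `num_nextn_predict_layers == 1` so that vLLM exposes the MTP weights under layer index 61 (the same index the MTP drafter layer is appended under in patch 2/6). These are derived from the mapping rules above and should be read as expected behaviour, not as a captured trace.

    # Expected vLLM (inference) name -> Megatron (training) name, per the rules above.
    EXPECTED_NAME_TRANSLATIONS = {
        # ordinary decoder layer: "model" -> "decoder", module name remapped
        "model.layers.3.self_attn.o_proj.weight":
            "decoder.layers.3.self_attention.linear_proj.weight",
        # MTP-specific projections keep their names, rooted under "mtp.layers.0"
        "model.layers.61.enorm.weight": "mtp.layers.0.enorm.weight",
        "model.layers.61.eh_proj.weight": "mtp.layers.0.eh_proj.weight",
        # the MTP transformer layer is renamed mtp_block -> transformer_layer
        "model.layers.61.mtp_block.self_attn.o_proj.weight":
            "mtp.layers.0.transformer_layer.self_attention.linear_proj.weight",
        # the shared head maps onto the main model's final norm and output layer
        "model.layers.61.shared_head.norm.weight": "mtp.final_layernorms.0.weight",
        "model.layers.61.shared_head.head.weight": "output_layer.weight",
    }

Because the MTP layer is appended to the last pipeline stage on the training side (the `_num_layer_list[-1] + 1` change in `vllm_weight_container.py`), these translated names are the ones the resharding path looks up when syncing weights between the Megatron actor and the vLLM drafter.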