From aa95535aafd7109e6ff74eff5efb91fc09c49346 Mon Sep 17 00:00:00 2001
From: xiecheng
Date: Mon, 8 Sep 2025 18:36:22 +0800
Subject: [PATCH] update 32b_yaml and remove md5

---
 configs/dapo_qwen25_32b_A3_32k.yaml             |  4 +-
 docs/features/deterministic_computation.md      |  2 +-
 mindspeed_rl/config_cls/rl_config.py            |  2 -
 mindspeed_rl/workers/actor_hybrid_worker.py     |  1 -
 .../resharding/megatron_sharding_manager.py     |  4 --
 mindspeed_rl/workers/resharding/utils.py        | 43 -------------------
 .../resharding/vllm_weight_container.py         | 29 +------------
 tests/st/resharding/test_resharding.py          |  1 -
 8 files changed, 5 insertions(+), 81 deletions(-)

diff --git a/configs/dapo_qwen25_32b_A3_32k.yaml b/configs/dapo_qwen25_32b_A3_32k.yaml
index 4dbe84b0..63654577 100644
--- a/configs/dapo_qwen25_32b_A3_32k.yaml
+++ b/configs/dapo_qwen25_32b_A3_32k.yaml
@@ -55,8 +55,8 @@ actor_config:
   finetune: false
   load: ./Qwen2.5-32B-tp8-pp2
   save: ./ckpt
-  no_load_optim: false
-  no_load_rng: false
+  no_load_optim: true
+  no_load_rng: true
   context_parallel_size: 2
   context_parallel_algo: megatron_cp_algo
   attention_mask_type: causal

diff --git a/docs/features/deterministic_computation.md b/docs/features/deterministic_computation.md
index 13c0e1b5..02627e14 100644
--- a/docs/features/deterministic_computation.md
+++ b/docs/features/deterministic_computation.md
@@ -45,7 +45,7 @@
 When enabling deterministic computation in the DAPO scenario, the dynamic sampling parameter must be set to false; only then is the data fed into each training iteration guaranteed to be identical.
 
 ```
-filter_groups_enable: true
+filter_groups_enable: false
 ```
 
 ### 2. Enable the deterministic computation parameter

diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index cba71c87..63e4044d 100644
--- a/mindspeed_rl/config_cls/rl_config.py
+++ b/mindspeed_rl/config_cls/rl_config.py
@@ -44,7 +44,6 @@ class RLConfig(BaseConfig):
 
         shuffle_mini_batch: Whether to shuffle minibatch (default: False)
         n_samples_per_prompt: Number of samples per prompt (default: 1)
-        enable_sharding_validate: Whether to enable sharding validation (default: False)
        tp_split_expert:
        use_tensorboard: Whether to use tensorboard (default: False)
        use_wandb: Whether to use wandb (default: False)
@@ -98,7 +97,6 @@
        self.use_kl_in_reward = False
        self.shuffle_mini_batch = False

-        self.enable_sharding_validate = False
        self.tp_split_expert = False

        self.use_tensorboard = False

diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 2e879458..16cfc22a 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -777,7 +777,6 @@ class ActorHybridWorkerBase(BaseWorker):
            optimizer_offload=self.generate_config.offload_train_optimizer,
            grad_offload=self.generate_config.offload_train_grad,
            train_param_offload=self.generate_config.offload_train_param,
-            enable_validate=self.rl_config.enable_sharding_validate,
            megatron_offloader=self.actor_offloader,
            noop_layers=self.megatron_config.noop_layers
        )

diff --git a/mindspeed_rl/workers/resharding/megatron_sharding_manager.py b/mindspeed_rl/workers/resharding/megatron_sharding_manager.py
index bc35f600..5ea6e30d 100644
--- a/mindspeed_rl/workers/resharding/megatron_sharding_manager.py
+++ b/mindspeed_rl/workers/resharding/megatron_sharding_manager.py
@@ -40,7 +40,6 @@ class MegatronShardingManager:
                 optimizer_offload=False,
                 grad_offload=False,
                 train_param_offload=False,
-                 enable_validate=False,
                 megatron_model=None,
                 model_config=None,
                 infer_tensor_parallel_size=None,
@@ -59,7 +58,6 @@
            optimizer (MegatronOptimizer): Optimizer instance used for model training.
            optimizer_offload (bool): Whether to offload optimizer operations to a separate device.
            grad_offload (bool): Whether to offload gradient computation to the CPU during training.
-            enable_validate (bool): Whether to enable communication data validate.
            megatron_model (nn.Module or nn.ModuleList): Megatron model instance.
            model_config (MegatronConfig): Configuration for the model.
            infer_tensor_parallel_size (int): Tensor parallel size during inference.
@@ -86,7 +84,6 @@
            moe_tp_extend_ep=moe_tp_extend_ep,
            parallel_state=parallel_state,
            weight_adaptor=self.weight_adaptor,
-            enable_validate=enable_validate,
            noop_layers=noop_layers,
            eplb_map=self.inference_engine.eplb_map,
            global_redundant_expert_num=self.inference_engine.global_redundant_expert_num,
@@ -96,7 +93,6 @@
        self.optimizer_offload = optimizer_offload
        self.grad_offload = grad_offload
        self.train_param_offload = train_param_offload
-        self.enable_validate = enable_validate
        self.inference_engine.offload_model_weights()

        self.megatron_offloader = megatron_offloader

diff --git a/mindspeed_rl/workers/resharding/utils.py b/mindspeed_rl/workers/resharding/utils.py
index cd5d5369..ff395a60 100644
--- a/mindspeed_rl/workers/resharding/utils.py
+++ b/mindspeed_rl/workers/resharding/utils.py
@@ -54,48 +54,5 @@ def get_tensor_parallel_partition_dim(param):
     return param.partition_dim
 
 
-def tp_md5_validate(infer_params_for_md5, origin_params_for_md5, log_prefix):
-    md5_tensor = bytes_to_tensor(origin_params_for_md5)
-    origin_params_md5_allgather_tensor = []
-    for _ in range(get_tp_allgather_world_size()):
-        origin_params_md5_allgather_tensor.append(torch.empty_like(md5_tensor))
-    torch.distributed.all_gather(origin_params_md5_allgather_tensor, md5_tensor, group=get_tp_allgather_group())
-    for index, params in enumerate(infer_params_for_md5):
-        recv_md5_tensor = bytes_to_tensor(params)
-        validate_md5(origin_params_md5_allgather_tensor[index], recv_md5_tensor, log_prefix)
-
-
-def update_md5_by_rank(infer_param, param, origin_params_for_md5, infer_params_for_md5):
-    # compute current param' md5 value at current rank
-    param_bytes = param.data.to(torch.float32).cpu().numpy().tobytes()
-    origin_params_for_md5.update(param_bytes)
-    # Calculate the md5 values of all received params in the TP group, separated by rank
-    for index, recv_param in enumerate(infer_param):
-        recv_param_bytes = recv_param.data.to(torch.float32).cpu().numpy().tobytes()
-        infer_params_for_md5[index].update(recv_param_bytes)
-
-
-def bytes_to_tensor(bytes_data):
-    md5_tensor = torch.tensor([int(h, 16) for h in bytes_data.hexdigest()], dtype=torch.int64,
-                              device=torch.cuda.current_device())
-    return md5_tensor
-
-
-def compute_md5(model):
-    hash_value = hashlib.md5()
-    for memory_buffer in model.memory_buffers.values():
-        param_bytes = memory_buffer.data.detach().to(torch.float32).cpu().numpy().tobytes()
-        hash_value.update(param_bytes)
-    md5_tensor = bytes_to_tensor(hash_value)
-    return md5_tensor
-
-
-def validate_md5(md5_tensor_src, md5_tensor, log_prefix):
-    if torch.equal(md5_tensor_src, md5_tensor):
-        logging.info(f"{log_prefix} md5 validate Hash: The weights of the two models match.")
-    else:
-        logging.info(f"{log_prefix} md5 validate Hash: The weights of the two models do not match.")
-
-
 def is_fake_tp_param(name, moe_tp_extended_ep):
     return 'mlp.experts.weight' in name and moe_tp_extended_ep

diff --git a/mindspeed_rl/workers/resharding/vllm_weight_container.py b/mindspeed_rl/workers/resharding/vllm_weight_container.py
index b99d36f8..0d8c206b 100644
--- a/mindspeed_rl/workers/resharding/vllm_weight_container.py
+++ b/mindspeed_rl/workers/resharding/vllm_weight_container.py
@@ -30,8 +30,8 @@ import vllm.distributed.parallel_state as ps
 from mindspeed_rl.workers.resharding.memory_buffer import build_model_weight_buffer, calc_padded_numel
 import mindspeed_rl.workers.resharding.utils
-from mindspeed_rl.workers.resharding.utils import get_tensor_parallel_partition_dim, tp_md5_validate, \
-    update_md5_by_rank, compute_md5, validate_md5, _build_infer_param_dict, get_tp_allgather_group, \
+from mindspeed_rl.workers.resharding.utils import get_tensor_parallel_partition_dim, \
+    _build_infer_param_dict, get_tp_allgather_group, \
     get_tp_allgather_world_size, is_tensor_parallel_param, get_tp_group, is_fake_tp_param
 from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.utils.utils import is_multimodal
 
@@ -48,7 +48,6 @@
                 moe_tp_extend_ep=False,
                 parallel_state=None,
                 weight_adaptor=None,
-                 enable_validate=False,
                 noop_layers=None,
                 eplb_map=None,
                 global_redundant_expert_num=0,
@@ -66,7 +65,6 @@
            moe_tp_extend_ep (bool): Controls whether expert model parameters are split across multiple GPUs.
            parallel_state (ModuleType): Megatron parallel state of the model.
            weight_adaptor (WeightAdaptor): Provides a set of tools to transfer from training weight to inference weight.
-            enable_validate (bool): Whether to enable communication data validate.
        """

        self.vllm_model = vllm_model
@@ -145,11 +143,6 @@
        # validate parallel configs
        self._validate_parallel_config()

-        # md5 validate
-        self.enable_validate = enable_validate
-        self.origin_params_for_md5 = None
-        self.infer_params_for_md5 = None
-
        self._rank = dist.get_rank()
        self._init_tensor_model_parallel_allgather_group()
        self._init_pipeline_model_parallel_allgather_group()
@@ -450,10 +443,6 @@
        name_pairs = sorted(list(set([(name, vpp_rank, self.weight_adaptor.replace_name_i2t(normal_layer_func(name, vpp_rank=vpp_rank)))
                                      for vpp_rank, names_per_vpp in enumerate(weight_names_meta) for name in names_per_vpp])))

-        if self.enable_validate:
-            self.origin_params_for_md5 = hashlib.md5()
-            self.infer_params_for_md5 = [hashlib.md5() for _ in range(get_tp_allgather_world_size())]
-
        # Check that the linear_fc1 and linear_fc2 weight shapes satisfy the expected relationship (fc1 holds the gate and up-projection parameters, so it is twice the size of fc2). Models that do not meet this condition are not supported.
        for _, vpp_rank, megatron_name in name_pairs:
            if not megatron_name.startswith("image_encoder") and megatron_name.endswith("linear_fc1.weight"):
@@ -479,10 +468,6 @@
                param = _transfer_from_megatron_division(megatron_param, megatron_name)
                weight_buffer.copy_by_name(hf_name, param)

-        # tp md5 validate
-        if self.enable_validate:
-            tp_md5_validate(self.infer_params_for_md5, self.origin_params_for_md5,
-                            f"rank[{self._rank}] tp params allgather")

    def _update_weight_buffers_ep(self):
        # Build the temporary experts_memory_buffers
@@ -556,14 +541,6 @@
            global_src = dist.get_global_rank(group=self._pp_group, group_rank=cur_pp_rank)
            for memory_buffer in self.weight_buffers[cur_pp_rank].memory_buffers.values():
                dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self._pp_group, async_op=False)
-            if self.enable_validate:
-                md5_tensor = compute_md5(self.weight_buffers[cur_pp_rank])
-                if self._rank == global_src:
-                    dist.broadcast(md5_tensor, group=self._pp_group, src=global_src, async_op=False)
-                else:
-                    md5_tensor_src = torch.zeros_like(md5_tensor, dtype=torch.int64, device=torch.cuda.current_device())
-                    dist.broadcast(md5_tensor_src, group=self._pp_group, src=global_src, async_op=False)
-                    validate_md5(md5_tensor_src, md5_tensor, f"rank[{self._rank}] pp resharding params")


    def get_expert_router(self, cur_rank, train_tp_ep_size, infer_tp_ep_size, world_size):
@@ -759,8 +736,6 @@
            # allocate a new tensor with proper size
            infer_param = [torch.empty_like(param) for _ in range(tp_allgather_size)]
            torch.distributed.all_gather(infer_param, param, group=tp_allgather_group)
-            if self.enable_validate:
-                update_md5_by_rank(infer_param, param, self.origin_params_for_md5, self.infer_params_for_md5)
            part_len = len(infer_param) // self._infer_tp_size
            start = self._rank % self._infer_tp_size
            part_param = infer_param[part_len * start:part_len * (start + 1)]

diff --git a/tests/st/resharding/test_resharding.py b/tests/st/resharding/test_resharding.py
index 2a9f3227..a7da3fd6 100644
--- a/tests/st/resharding/test_resharding.py
+++ b/tests/st/resharding/test_resharding.py
@@ -360,7 +360,6 @@ class TestActor():
            optimizer_offload=True,
            grad_offload=True,
            train_param_offload=True,
-            enable_validate=False,
            megatron_offloader=self.megatron_offloader
        )
        torch.cuda.empty_cache()
--
Gitee
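
For reference, the gather-then-slice resharding step that this patch keeps (the `@@ -759` hunk above) can be read in isolation as follows. This is a minimal sketch under stated assumptions: it presumes an initialized `torch.distributed` process group, and the helper name `reshard_tp_param`, its signature, and the `partition_dim` default are illustrative, not APIs from this repo.

```python
import torch
import torch.distributed as dist


def reshard_tp_param(param: torch.Tensor, tp_group, infer_tp_size: int,
                     partition_dim: int = 0) -> torch.Tensor:
    """All-gather training-TP shards, then keep this rank's inference-TP slice."""
    train_tp_size = dist.get_world_size(group=tp_group)
    # One buffer per training TP rank, filled by all_gather.
    shards = [torch.empty_like(param) for _ in range(train_tp_size)]
    dist.all_gather(shards, param, group=tp_group)
    # Each inference TP rank owns train_tp_size // infer_tp_size contiguous shards.
    part_len = train_tp_size // infer_tp_size
    start = dist.get_rank() % infer_tp_size
    kept = shards[part_len * start:part_len * (start + 1)]
    # Re-join the kept shards along the dimension the weight was split on.
    return torch.cat(kept, dim=partition_dim)
```

The `part_len`/`start` arithmetic mirrors the lines retained in the hunk; the md5 hashing that used to piggyback on the same `all_gather` is exactly the code this patch removes.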