diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index 7abfbe8f10477327802b25a5ad856bffc00c60f0..e9fedd01aad20405bfaa1b3655aaf90f59d29a83 100644
--- a/mindspeed_rl/config_cls/rl_config.py
+++ b/mindspeed_rl/config_cls/rl_config.py
@@ -50,7 +50,8 @@ class RLConfig(BaseConfig):
         blocking: Whether to enable blocking mode (default: False)
         num_cpus_for_local_task: Number of CPUs for local ray task (default: 1)
         num_cpus_for_placement_group: Number of CPUs for ray worker placement group
-
+        ras_flag: Whether to enable RAS asynchronous checkpoint saving (default: False)
+        save_checkpoint_count: Number of recent checkpoints to retain on disk (default: 2)
         is_multimodal: Whether base model is a multimodal model or not (default: False)
 
         # Default values can still be defined if no config is provided
@@ -117,6 +118,10 @@ class RLConfig(BaseConfig):
         self.use_dynamic_bsz = False
         self.max_packing_token_size = 4096
 
+        # RAS control
+        self.ras_flag = False
+        self.save_checkpoint_count = 2
+
         if config_dict.get("actor_resource") is not None:
             for key, _ in config_dict["actor_resource"].items():
                 if key not in self.actor_resource:
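For reference, the two new options are read from the same config_dict that populates the rest of RLConfig, so they can be overridden like any other field. A minimal sketch, assuming RLConfig is constructed directly from a dict as the __init__ above suggests; the override values are illustrative only:

    from mindspeed_rl.config_cls.rl_config import RLConfig

    # Illustrative overrides only; keys not set here keep the defaults above.
    rl_config = RLConfig({
        "ras_flag": True,            # enable the RAS asynchronous checkpoint path
        "save_checkpoint_count": 3,  # keep the three most recent saves on disk
    })
    print(rl_config.ras_flag, rl_config.save_checkpoint_count)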
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 84a41d2da302b1141593d0e18e031dd5a8296aa8..edc5a66c76940b4ba9cca0763f4153a7d4d5007f 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -1,14 +1,18 @@
 # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
 
 import time
+import os
 import dataclasses
 import copy
 import gc
 from typing import Callable
+import threading
 
 import ray
 from torch import nn
+import random
 import torch
+import numpy as np
 from transformers import AutoConfig
 
 from mindspeed_rl.config_cls.megatron_config import MegatronConfig
@@ -21,11 +25,15 @@
 from mindspeed_rl.models.rollout.vllm_engine import VLLMInferEngine
 from mindspeed_rl.utils.tokenizer import BaseTokenizer
 from mindspeed_rl.utils.utils import MsProbe
 from mindspeed_rl.workers.base_worker import BaseWorker
+from mindspeed_rl.utils.loggers import Loggers
 from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronShardingManager, MegatronOffLoader
 from mindspeed_rl.utils.utils import num_floating_point_operations, get_attr_wrapped_model, mstx_timer_decorator, profiler_start, profiler_step, is_multimodal
 from mindspeed_rl.utils.pad_process import remove_padding_and_split_to_list, truncate_rows
 
+logger = Loggers('actor_hybrid_worker')
+
+
 class ActorHybridWorkerBase(BaseWorker):
     """
     ActorHybridWorker class. This class implements the hybrid worker logic for training and inference.
@@ -130,7 +138,20 @@ class ActorHybridWorkerBase(BaseWorker):
         self.empty_cache()
 
     def get_iteration(self):
-        return self.args.iteration
+        if not self.rl_config.ras_flag:
+            return self.args.iteration
+
+        tracker_filename = os.path.join(self.megatron_config.load, 'latest_checkpointed_iteration.txt')
+        if not os.path.exists(tracker_filename):
+            return self.args.iteration
+        with open(tracker_filename, 'r') as f:
+            meta_string = f.read().strip()
+        try:
+            iteration = int(meta_string)
+        except ValueError:
+            return self.args.iteration
+        logger.info(f"get iteration: {iteration - 1}")
+        return iteration - 1
 
     def get_consumed_train_samples(self):
         return self.args.consumed_train_samples
@@ -221,11 +242,335 @@
             )
         profiler_step(self.actor_profiler)
 
+    def get_checkpoint_name(self, checkpoints_path, iteration,
+                            pipeline_parallel=None,
+                            tensor_rank=None, pipeline_rank=None,
+                            expert_parallel=None, expert_rank=None):
+        directory = 'iter_{:07d}'.format(iteration)
+        # Use both the tensor and pipeline MP rank.
+        if pipeline_parallel is None:
+            pipeline_parallel = (self.parallel_state.get_pipeline_model_parallel_world_size() > 1)
+        if tensor_rank is None:
+            tensor_rank = self.parallel_state.get_tensor_model_parallel_rank()
+        if pipeline_rank is None:
+            pipeline_rank = self.parallel_state.get_pipeline_model_parallel_rank()
+        if expert_parallel is None:
+            expert_parallel = (self.parallel_state.get_expert_model_parallel_world_size() > 1)
+        if expert_rank is None:
+            expert_rank = self.parallel_state.get_expert_model_parallel_rank()
+
+        if not pipeline_parallel:
+            common_path = os.path.join(checkpoints_path, directory, f'mp_rank_{tensor_rank:02d}')
+        else:
+            common_path = os.path.join(checkpoints_path, directory, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
+
+        if expert_parallel:
+            common_path = common_path + f'_{expert_rank:03d}'
+
+        return os.path.join(common_path, "model_optim_rng.pt")
+
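The added get_checkpoint_name follows Megatron's per-rank directory layout under the save path. A small standalone sketch of the resulting file name for one rank; example_checkpoint_path is a hypothetical helper that repeats the formatting of the pipeline-parallel branch above, and the values are illustrative:

    import os

    def example_checkpoint_path(save_dir, iteration, tensor_rank, pipeline_rank):
        # Repeats the formatting of the pipeline-parallel branch in get_checkpoint_name.
        directory = 'iter_{:07d}'.format(iteration)
        common_path = os.path.join(save_dir, directory, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}')
        return os.path.join(common_path, "model_optim_rng.pt")

    print(example_checkpoint_path("/data/ckpt", 100, 1, 2))
    # /data/ckpt/iter_0000100/mp_rank_01_002/model_optim_rng.pt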
+    def generate_state_dict(self, args, model, iteration, optimizer, opt_param_scheduler, rng_state):
+        # Arguments, iteration, and model.
+        state_dict = {}
+        state_dict['args'] = args
+        state_dict['checkpoint_version'] = 3.0
+        state_dict['iteration'] = iteration
+
+        def recursive_to_cpu(obj, name: str):
+            if isinstance(obj, torch.Tensor):
+                return obj.cpu()
+            elif isinstance(obj, dict):
+                return {k: recursive_to_cpu(v, name) for k, v in obj.items()}
+            elif isinstance(obj, list):
+                return [recursive_to_cpu(v, name) for v in obj]
+            else:
+                return obj
+
+        npu_stat_dict_checkpoint = model.state_dict_for_save_checkpoint()
+        state_dict['model'] = recursive_to_cpu(npu_stat_dict_checkpoint, 'model')
+
+        if optimizer is not None:
+            optim_stat_dict = optimizer.state_dict()
+            state_dict['optimizer'] = recursive_to_cpu(optim_stat_dict, 'optim')
+
+        if opt_param_scheduler is not None:
+            param_stat_dict = opt_param_scheduler.state_dict()
+            state_dict['opt_param_scheduler'] = recursive_to_cpu(param_stat_dict, "param")
+
+        state_dict["rng_state"] = rng_state
+
+        return state_dict
+
+    def get_rng_state(self):
+        """ collect rng state across data parallel ranks """
+        from megatron.core import mpu, tensor_parallel
+
+        args = self.args
+        rng_state = {
+            'random_rng_state': random.getstate(),
+            'np_rng_state': np.random.get_state(),
+            'torch_rng_state': torch.get_rng_state(),
+            'cuda_rng_state': torch.cuda.get_rng_state(),
+            'rng_tracker_states': tensor_parallel.get_cuda_rng_tracker().get_states()}
+
+        rng_state_list = None
+        if torch.distributed.is_initialized() and \
+                mpu.get_data_parallel_world_size() > 1 and \
+                args.data_parallel_random_init:
+            rng_state_list = \
+                [None for i in range(mpu.get_data_parallel_world_size())]
+            torch.distributed.all_gather_object(
+                rng_state_list,
+                rng_state,
+                group=mpu.get_data_parallel_group())
+        else:
+            rng_state_list = [rng_state]
+
+        return rng_state_list
+
+    @staticmethod
+    def check_file_integrity(filename: str) -> int:
+        max_wait_time = 300
+        start_time = time.time()
+        while True:
+            try:
+                if os.path.exists(filename) and os.path.getsize(filename) > 0:
+                    with open(filename, 'rb') as f:
+                        header = f.read(1024)
+                        if b'PK\x03\x04' in header:
+                            return 0
+            except Exception:
+                pass
+            if time.time() - start_time > max_wait_time:
+                logger.error(f"save checkpoint file {filename} failed")
+                return 1
+            time.sleep(0.1)
+
+    @staticmethod
+    def save_checkpoints_file(filename: str, state_dict):
+        dirname = os.path.dirname(filename)
+        os.makedirs(dirname, exist_ok=True)
+        torch.save(state_dict, filename)
+
+    @staticmethod
+    def delete_directory(path: str):
+        if not os.path.exists(path):
+            return
+        for root, directories, files in os.walk(path, topdown=False):
+            for file in files:
+                file_path = os.path.join(root, file)
+                try:
+                    os.remove(file_path)
+                except Exception as e:
+                    logger.error(f"delete file {file_path} failed: {e}")
+            for directory in directories:
+                dir_path = os.path.join(root, directory)
+                try:
+                    os.rmdir(dir_path)
+                except Exception as e:
+                    logger.error(f"delete directory {dir_path} failed: {e}")
+        try:
+            os.rmdir(path)
+        except Exception as e:
+            logger.error(f"delete directory {path} failed: {e}")
+
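check_file_integrity assumes torch.save writes the zip-based serialization format, whose local file header begins with the bytes PK\x03\x04. A quick standalone check of that assumption:

    import io

    import torch

    buffer = io.BytesIO()
    torch.save({"w": torch.zeros(1)}, buffer)
    # torch.save uses zip-based serialization by default, so the stream starts with
    # the zip local-file-header magic that check_file_integrity scans for.
    print(buffer.getvalue()[:4] == b'PK\x03\x04')  # True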
[f"{self.args.save}/iter_{iteration:07d}/mp_rank_{i:02d}_{j:03d}" + for i in range(dp_size) for j in range(pp_size)] + else: + if ep_size > 1: + all_dir_names = [f"{self.args.save}/iter_{iteration:07d}/mp_rank_{i:02d}_{j:03d}" + for i in range(dp_size) for j in range(ep_size)] + else: + all_dir_names = [f"{self.args.save}/iter_{iteration:07d}/mp_rank_{i:02d}" + for i in range(dp_size)] + + return all_dir_names + + def check_all_dp_ckpt_files(self, iteration) -> bool: + all_dir_names = self.get_all_check_directories(iteration) + + if self.megatron_config.use_distributed_optimizer is True: + files = ["distrib_optim.pt", "model_optim_rng.pt"] + else: + files = ["model_optim_rng.pt"] + + max_wait_time = 300 + start_time = time.time() + while True: + success_flag = True + for dir_name in all_dir_names: + if not os.path.exists(dir_name): + success_flag = False + break + + for file in files: + file_name = os.path.join(dir_name, file) + try: + if not os.path.exists(file_name) or os.path.getsize(file_name) == 0: + success_flag = False + break + with open(file_name, 'rb') as f: + header = f.read(1024) + if b'PK\x03\x04' not in header: + success_flag = False + break + except Exception as e: + pass + + if success_flag: + return True + if time.time() - start_time > max_wait_time: + return False + time.sleep(0.1) + + def async_save_distribute_ckpt(self, distribute_optim, iteration: int): + model_rng_checkpoint_name = self.get_checkpoint_name(self.args.save, iteration) + if self.megatron_config.use_distributed_optimizer is True: + distribute_optim_ckpt_name = os.path.join(os.path.dirname(model_rng_checkpoint_name), "distrib_optim.pt") + logger.info(f"set checkpoint distribute optim name: {distribute_optim_ckpt_name}") + self.save_checkpoints_file(distribute_optim_ckpt_name, distribute_optim) + ckpt_files = [model_rng_checkpoint_name, distribute_optim_ckpt_name] + else: + ckpt_files = [model_rng_checkpoint_name] + for ckpt_file in ckpt_files: + if self.check_file_integrity(ckpt_file) != 0: + return + + # write iteration result + if self._rank == 0: + if self.check_all_dp_ckpt_files(iteration) is not True: + return + + tracker_filename = os.path.join(self.args.save, "latest_checkpointed_iteration.txt") + with open(tracker_filename, 'w') as f: + f.write(str(iteration)) + if (self.rl_config.save_checkpoint_count and + iteration > self.rl_config.save_checkpoint_count * self.megatron_config.save_interval): + delete_iteration = iteration - self.rl_config.save_checkpoint_count * self.megatron_config.save_interval + directory = os.path.join(self.args.save, 'iter_{:07d}'.format(delete_iteration)) + self.delete_directory(directory) + logger.info(f"delete old checkpoint directory: {directory}") + + def async_save_model_rng_ckpt(self, iteration: int): + model_rng_ckpt_name = self.get_checkpoint_name(self.args.save, iteration) + logger.info(f"set checkpoint model_rng name: {model_rng_ckpt_name}") + stat_dict = self.generate_state_dict(self.args, self.model[0], iteration, + self.optimizer, self.opt_param_scheduler, self.get_rng_state()) + stat_dict['num_floating_point_operations_so_far'] = self.num_floating_point_operations_so_far + self.save_checkpoints_file(model_rng_ckpt_name, stat_dict) + + def construct_local_shards(self): + optimizer = self.optimizer + + for gbuf_idx, gbuf_range_maps in enumerate(optimizer.gbuf_ranges): + for _, gbuf_range_map_for_all_buckets in gbuf_range_maps.items(): + buffer_numel_unpadded = optimizer.buffers[gbuf_idx].numel_unpadded + for bucket_idx, gbuf_range_map in 
+    def construct_local_shards(self):
+        optimizer = self.optimizer
+
+        for gbuf_idx, gbuf_range_maps in enumerate(optimizer.gbuf_ranges):
+            for _, gbuf_range_map_for_all_buckets in gbuf_range_maps.items():
+                buffer_numel_unpadded = optimizer.buffers[gbuf_idx].numel_unpadded
+                for bucket_idx, gbuf_range_map in enumerate(gbuf_range_map_for_all_buckets):
+                    gbuf_world_numel = optimizer.buffers[gbuf_idx].buckets[bucket_idx].grad_data.numel()
+                    gbuf_local_numel = gbuf_world_numel // optimizer.data_parallel_group_gloo.size()
+                    gbuf_world_numel_unpadded = optimizer.buffers[gbuf_idx].buckets[bucket_idx].numel_unpadded
+
+                    local_shards = {
+                        key: torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
+                        for key in ("param", "exp_avg", "exp_avg_sq")
+                    }
+
+                    # Build contiguous DP rank shards (for param + optim states).
+                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
+                        group_index, group_order = optimizer.model_param_group_index_map[model_param]
+                        main_param = optimizer.optimizer.param_groups[group_index]["params"][group_order]
+                        optim_state = optimizer.optimizer.state[main_param]
+
+                        tensors = {"param": main_param, **optim_state}
+                        gbuf_local_start = param_range_map["gbuf_local"].start
+                        gbuf_local_end = param_range_map["gbuf_local"].end
+
+                        for key in local_shards:
+                            local_shards[key][gbuf_local_start:gbuf_local_end].data.copy_(tensors[key].detach().cpu())
+
+        return local_shards, gbuf_world_numel, gbuf_world_numel_unpadded, buffer_numel_unpadded
+
+    def cpu_gather_all_shards(self, local_shards, gbuf_world_numel, gbuf_world_numel_unpadded, buffer_numel_unpadded):
+        optimizer = self.optimizer
+        data_parallel_world_size = optimizer.data_parallel_group_gloo.size()
+        data_parallel_group_gloo = optimizer.data_parallel_group_gloo
+        data_parallel_rank = torch.distributed.get_rank(optimizer.data_parallel_group_gloo)
+        data_parallel_global_ranks = torch.distributed.get_process_group_ranks(optimizer.data_parallel_group_gloo)
+        gbuf_local_numel = gbuf_world_numel // optimizer.data_parallel_group_gloo.size()
+
+        world_tensors = {}
+        if data_parallel_rank == 0:
+            world_tensors = {
+                key: torch.zeros((buffer_numel_unpadded,), dtype=torch.float32, device="cpu")
+                for key in ("param", "exp_avg", "exp_avg_sq")
+            }
+            world_tensors["numel_unpadded"] = buffer_numel_unpadded
+
+        for key, send_tensor in local_shards.items():
+            recv_tensors = [
+                torch.zeros((gbuf_local_numel,), dtype=torch.float32, device="cpu")
+                for _ in range(data_parallel_world_size)
+            ] if data_parallel_rank == 0 else None
+
+            torch.distributed.gather(
+                send_tensor, recv_tensors, data_parallel_global_ranks[0], data_parallel_group_gloo
+            )
+
+            # Concatenate.
+            if data_parallel_rank == 0:
+                recv_tensors_concatenated = torch.cat(recv_tensors)
+                end = gbuf_world_numel_unpadded
+                world_tensors[key][0:end].copy_(recv_tensors_concatenated[:end])
+
+        dtype_state = {}
+        dtype_state[next(iter(optimizer.gbuf_ranges[0]))] = world_tensors
+        state = {"buckets_coalesced": True}
+        state[0] = dtype_state
+
+        return state
+
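cpu_gather_all_shards relies on torch.distributed.gather over the optimizer's gloo data-parallel group: every rank sends its CPU shard, and only DP rank 0 supplies receive buffers and reassembles the full buffer. A single-process sketch of that call pattern, using a world size of 1 and an ad-hoc TCP rendezvous so it runs standalone; tensor sizes are illustrative:

    import torch
    import torch.distributed as dist

    # World size 1 so the gather call below runs standalone; the port is arbitrary.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1)

    send_tensor = torch.arange(4, dtype=torch.float32)  # this rank's CPU shard
    recv_tensors = ([torch.zeros(4) for _ in range(dist.get_world_size())]
                    if dist.get_rank() == 0 else None)  # only rank 0 allocates receive buffers

    dist.gather(send_tensor, recv_tensors, dst=0)
    if dist.get_rank() == 0:
        print(torch.cat(recv_tensors))  # shards concatenated in data-parallel rank order
    dist.destroy_process_group()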
+    def async_gather_save_ckpt(self, iteration: int, local_shards, gbuf_world_numel,
+                               gbuf_world_numel_unpadded, buffer_numel_unpadded):
+        gathered_state = self.cpu_gather_all_shards(local_shards, gbuf_world_numel,
+                                                    gbuf_world_numel_unpadded, buffer_numel_unpadded)
+        if torch.distributed.get_rank(self.optimizer.data_parallel_group) == 0:
+            self.async_save_distribute_ckpt(gathered_state, iteration)
+
+    def async_cpu_gather_shards(self, iteration: int):
+        (local_shards, gbuf_world_numel,
+         gbuf_world_numel_unpadded, buffer_numel_unpadded) = self.construct_local_shards()
+
+        thread_save = threading.Thread(target=self.async_gather_save_ckpt,
+                                       args=(iteration, local_shards, gbuf_world_numel,
+                                             gbuf_world_numel_unpadded, buffer_numel_unpadded, ))
+        thread_save.start()
+
     def save_ckpt(self, iteration: int):
-        self.sharding_manager.enter_train_mode()
-        self.save_checkpoint(iteration, self.model, self.optimizer, self.opt_param_scheduler,
-                             self.num_floating_point_operations_so_far)
-        self.sharding_manager.exit_train_mode()
+        rank_id = torch.distributed.get_rank(self.optimizer.data_parallel_group)
+        if rank_id == 0:
+            thread_save = threading.Thread(target=self.async_save_model_rng_ckpt, args=(iteration, ))
+            thread_save.start()
+
+        if self.rl_config.ras_flag:
+            self.async_cpu_gather_shards(iteration)
+            return
+
+        distribute_dict = self.optimizer.get_parameter_state_dp_zero()
+        # only data parallel rank 0 saves the distributed optimizer state
+        if rank_id == 0:
+            thread_save = threading.Thread(target=self.async_save_distribute_ckpt,
+                                           args=(distribute_dict, iteration, ))
+            thread_save.start()
 
     @mstx_timer_decorator
     def generate_sequences(self):
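With this change, save_ckpt returns to the training loop while background threads serialize the CPU copies and later update latest_checkpointed_iteration.txt, which is why the polling integrity checks above exist. A minimal sketch of that fire-and-forget pattern; the paths and state dict are placeholders, not the worker's real objects:

    import os
    import tempfile
    import threading

    import torch

    def background_save(state, path):
        # Runs off the training thread, like async_save_model_rng_ckpt above.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(state, path)

    state_dict = {"iteration": 1, "model": {"w": torch.zeros(2)}}
    ckpt_path = os.path.join(tempfile.mkdtemp(), "iter_0000001", "model_optim_rng.pt")

    saver = threading.Thread(target=background_save, args=(state_dict, ckpt_path))
    saver.start()
    # ... the training step would continue here while the file is written ...
    saver.join()
    print(os.path.getsize(ckpt_path) > 0)  # True once the background save completes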