diff --git a/cli/train_grpo.py b/cli/train_grpo.py index fb687e1756654202f31ad9334c4521bd6b124b69..944bb2a24a52e69d622893731cbc63885553c5d5 100644 --- a/cli/train_grpo.py +++ b/cli/train_grpo.py @@ -182,6 +182,7 @@ def train(config): train_iters=actor_config.train_iters, save_interval=actor_config.save_interval, dataset_additional_keys=actor_config.dataset_additional_keys, + transfer_dock_profiler_config=profiler_config["transfer_dock"], **rl_config.dict() ) @@ -227,6 +228,10 @@ def parse_training_config(config: Dict): config.get("profiler_config", {}).get("integrated", {}), role="integrated" ), + "transfer_dock": ProfilerConfig( + config.get("profiler_config", {}).get("transfer_dock", {}), + role="transfer_dock" + ), }) actor_config.max_prompt_length = rl_config.max_prompt_length diff --git a/docs/features/profiler.md b/docs/features/profiler.md index e1c6b9516876091eaa43921b21298035fffb2fd7..39ff9fea8c9828da378e0a71daa446e9eff1edf0 100644 --- a/docs/features/profiler.md +++ b/docs/features/profiler.md @@ -34,6 +34,7 @@ profiler_config: | 参数 | 说明 | 可选值 | |------|------|--------| +| integrated | 采集性能数据的ray进程 | 可配置integrated(integrated worker进程性能数据)、transfer_dock(TD进程性能数据) | | profile | 性能分析开关 | true/false,所有性能数据采集均依赖该开关开启 | | mstx | 轻量化打点采集开关 | true/false,默认值false,启用/关闭轻量化打点采集,需要查看轻量化打点性能数据时需开启 | | stage | 性能数据采集阶段 | all(采集所有阶段性能数据)、actor_generate(采集actor模型生成阶段性能数据)、actor_compute_log_prob(采集actor模型计算log概率阶段性能数据)、reference_compute_log_prob(采集reference参考模型计算log概率阶段性能数据)、actor_update(采集模型更新阶段性能数据) | @@ -74,7 +75,7 @@ GRPO算法涉及多个worker模块交互,包含训练、推理等流程。在 profile_with_npu: true ``` -- **适用场景**: 轻量级采集包含自定义打点和所有通信算子的内置打点,MindSpeed-RL已集成所有worker的关键计算函数、dispatch_transfer_dock_data、resharding等关键函数的自定义打点。如需查看某代码片段在timeline中的执行耗时,可通过以下两种方式在MindSpeed-RL中添加自定义打点: +- **适用场景**: 轻量级采集包含自定义打点和所有通信算子的内置打点,MindSpeed-RL已集成所有worker的关键计算函数、dispatch_transfer_dock_data、resharding等关键函数的自定义打点,同时也包括TD进程的数据处理函数执行耗时。如需查看某代码片段在timeline中的执行耗时,可通过以下两种方式在MindSpeed-RL中添加自定义打点: ```python # 方式一:使用装饰器装饰函数 
diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index a38d0bb9d42884b7a870c62280cdcc51dab7ddca..89ebfc4b45c9361915e609fccb3545f6802fad31 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -61,6 +61,7 @@ class RayGRPOTrainer(RayBaseTrainer): n_samples_per_prompt: int = 1, tokenizer: BaseTokenizer = None, dataset_additional_keys: List[str] = None, + transfer_dock_profiler_config: ProfilerConfig = None, blocking: bool = False, guarantee_order: bool = False, num_cpus_for_local_task: int = 1, @@ -90,13 +91,15 @@ class RayGRPOTrainer(RayBaseTrainer): self.transfer_dock = None self.metrics = Metric() + self.transfer_dock_profiler_config = transfer_dock_profiler_config self.transfer_dock_init() self.kwargs = kwargs self.set_actor_log_prob_skip_flag() def transfer_dock_init(self): self.transfer_dock = GRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt, - self.metrics, addition_columns=self.dataset_additional_keys) + self.metrics, addition_columns=self.dataset_additional_keys, + profiler_config=self.transfer_dock_profiler_config) self.actor_worker.sync_init_transfer_dock(self.transfer_dock) self.ref_worker.sync_init_transfer_dock(self.transfer_dock) for reward in self.reward_list: @@ -126,7 +129,7 @@ class RayGRPOTrainer(RayBaseTrainer): logger.info('sync start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters)) else: logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters)) - + ray.get(self.transfer_dock.prof_step.remote()) while iteration < self.train_iters: ray.get(self.transfer_dock.clear.remote()) @@ -190,6 +193,7 @@ class RayGRPOTrainer(RayBaseTrainer): self.wandb.log_metrics(metrics.metric, iteration) if iteration % self.save_interval == 0: self.save_checkpoint(iteration) + ray.get(self.transfer_dock.prof_step.remote()) logger.info('after grpo training is done') 
class ThreadSafeProfiler:
    """Run the transfer-dock profiler's start/step calls on one dedicated thread.

    ``start()`` and ``step()`` only enqueue a command; a daemon worker thread
    pops the commands in order and performs the actual ``profiler_start`` /
    ``profiler_step`` calls, so every profiler call is issued from that single
    thread regardless of which thread invoked the public method.
    """

    def __init__(self, profiler_config):
        """Prepare transfer-dock profiler settings and launch the worker thread.

        Args:
            profiler_config: Profiler settings object. It is shallow-copied
                before the overrides below are applied, so the caller's
                object is left untouched.
        """
        self.command_queue = Queue()
        self.profiler = None
        # Work on a private copy: the overrides below are specific to the
        # transfer-dock profiler and must not leak back into the caller's config.
        self.profiler_config = copy.copy(profiler_config)
        self.profiler_config.profile_level = "level_none"
        self.profiler_config.mstx = True
        self.profiler_config.stage = "all"
        self.profiler_config.profile_with_cpu = False
        self.profiler_config.profile_with_npu = True

        # Start a dedicated worker thread: the profiler's start and stop must
        # happen on the same thread, whereas transfer_dock method calls are
        # asynchronous and run on an arbitrary thread of the thread pool.
        self.profiler_thread = threading.Thread(
            target=self._profiler_loop,
            daemon=True
        )
        self.profiler_thread.start()

    def _profiler_loop(self):
        """Worker-thread loop: block on the queue and execute each command."""
        while True:
            cmd = self.command_queue.get()
            if cmd == "start":
                self.profiler = profiler_start(self.profiler_config, "transfer_dock")
            elif cmd == "step":
                profiler_step(self.profiler)

    def start(self):
        """Queue a request for the worker thread to start the profiler."""
        self.command_queue.put("start")

    def step(self):
        """Queue a request for the worker thread to advance the profiler one step."""
        self.command_queue.put("step")
timeout_interval if timeout_interval is not None else 5 + @mstx_timer_decorator def _put( self, experience_columns: List[str], @@ -121,6 +158,7 @@ class TransferDock(ABC): self.experience_data[single_column][index] = experience[column_idx][i] self.experience_data_status[single_column][index] = 1 + @mstx_timer_decorator def _get(self, experience_columns: List[str], indexes: List[int]): """Get data based on row and column numbers. @@ -165,6 +203,7 @@ class TransferDock(ABC): return experience + @mstx_timer_decorator def _wait_for_data(self, single_column: str, indexes: List[int]): """Wait for data in which column and row to be ready. @@ -196,6 +235,7 @@ class TransferDock(ABC): itemgetter(*indexes)(self.experience_data_status[single_column]) ) == len(indexes) + @mstx_timer_decorator def _clear_experience_data_and_status(self, indexes=None): """Clear data and data status in TransferDock. @@ -259,6 +299,7 @@ class GRPOTransferDock(TransferDock): addition_consumers: Union[List[str], None] = None, timeout: Union[int, None] = None, timeout_interval: Union[int, None] = None, + profiler_config: ProfilerConfig = None, ) -> None: """GRPOTransferDock initialize. 
def prof_step(self):
    """Start the transfer-dock profiler on the first call, step it afterwards.

    Does nothing when no profiler config was supplied (``profiler_config``
    defaults to ``None``) or when the config's ``profile`` switch is off.
    """
    config = self.profiler_config
    # A missing config means profiling was never requested; treat it the same
    # as a disabled one instead of failing on the attribute access.
    if config is None or not config.profile:
        return
    if not self.profiler:
        # First call with profiling enabled: create the profiler and start it.
        self.profiler = ThreadSafeProfiler(config)
        self.profiler.start()
    else:
        self.profiler.step()