diff --git a/configs/envs/runtime_env.yaml b/configs/envs/runtime_env.yaml
index 7ac3d617a32ee76ae3008875a1d8253408323c66..24a5e1a64026cdf67b585f51be26abef133ace6a 100644
--- a/configs/envs/runtime_env.yaml
+++ b/configs/envs/runtime_env.yaml
@@ -1,22 +1,43 @@
 env_vars:
+  # Disable Ray's automatic setting of ASCEND_RT_VISIBLE_DEVICES
   RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES: 'true'
+  # Whether tokenizers parallelism is enabled
   TOKENIZERS_PARALLELISM: 'true'
+  # NCCL debug log level
   NCCL_DEBUG: 'WARN'
+  # Allow the NPU memory allocator to dynamically expand already-allocated segments
   PYTORCH_NPU_ALLOC_CONF: 'expandable_segments:True'
+  # HCCL connection timeout
   HCCL_CONNECT_TIMEOUT: '1800'
+  # HCCL execution timeout
   HCCL_EXEC_TIMEOUT: '3600'
+  # Base port for HCCL communication
   HCCL_IF_BASE_PORT: '48000'
+  # Maximum number of device connections
   CUDA_DEVICE_MAX_CONNECTIONS: '1'
+  # Whether Hydra prints the full error stack
   HYDRA_FULL_ERROR: '1'
-  VLLM_DP_SIZE: '1'
+
+  # vLLM data parallelism (DP) size, i.e. the number of data shards; for MoE models it should match EP, for dense models set it to 1
+  VLLM_DP_SIZE: '4'
+  # Maximum buffer size (MB) for a single HCCL transfer; affects cross-device communication efficiency
   HCCL_BUFFSIZE: '256'
+  # Use vLLM's V1 engine API; compatibility option
   VLLM_USE_V1: '1'
+  # vLLM version to use
   VLLM_VERSION: '0.9.0'
+  # Ascend torchair graph mode optimization (1 = enable); improves execution efficiency
   VLLM_ENABLE_GRAPH_MODE: '0'
+  # vLLM communication-computation fused operator (MC2) scheduling (1 = enable)
   VLLM_ENABLE_MC2: '0'
+  # HCCL operator expansion mode (AIV = AI Vector mode); enables advanced communication optimization
   HCCL_OP_EXPANSION_MODE: "AIV"
+  # Enable the vLLM top-k performance optimization
   VLLM_ENABLE_TOPK_OPTIMZE: "1"
+# Network interface for GLOO communication
 # GLOO_SOCKET_IFNAME: "Your SOCKET IFNAME"
+# Network interface for TP-related communication
 # TP_SOCKET_IFNAME: "Your SOCKET IFNAME"
+# Network interface for HCCL communication
 # HCCL_SOCKET_IFNAME: "Your SOCKET IFNAME"
\ No newline at end of file
diff --git a/docs/algorithms/grpo.md b/docs/algorithms/grpo.md
index e69b0216e6d737fe8c0c7b7fa11c85a6e25b6e0a..ff779c7fb5e9c3bdcf147a424760228e2470ae43 100644
--- a/docs/algorithms/grpo.md
+++ b/docs/algorithms/grpo.md
@@ -97,6 +97,93 @@
 bash examples/grpo/grpo_trainer_qwen25_7b.sh
 ```
 ***Note: the directory hierarchy of code, weights, data, etc. must be identical on all nodes, and ray must be started from the MindSpeed-RL directory on every node***
+## Logged Metrics
+
+**Timing metrics**
+
+| Metric | Description |
+| --- | --- |
+| `timing/all` | Total time of one iteration |
+| `timing/update` | Time spent by the actor model on update in one iteration |
+| `timing/rollout` | Time spent by the actor model on rollout in one iteration |
+| `timing/old_log_p` | Time spent by the actor model computing log p in one iteration |
+| `timing/reference_model` | Time spent by the reference model computing log p in one iteration |
+| `timing/resharding_to_train` | Time to reshard weights to training mode |
+| `timing/resharding_to_infer` | Time to reshard weights to inference mode |
+| `timing/adv` | Time to compute advantages |
+| `timing/non_overlap_reference_model` | Portion of the reference model's log p time that is not hidden by overlap |
+| `timing/non_overlap_rule_reward` | Portion of the rule reward time that is not hidden by overlap |
+| `timing/non_overlap_reward_model` | Portion of the reward model time that is not hidden by overlap |
+| `timing/non_overlap_adv` | Portion of the advantage computation time that is not hidden by overlap |
+| `timing/overlap_reference_model` | Overlap between the actor model's log p computation and the reference model's log p computation |
+| `timing/overlap_update` | Overlap between the actor model's log p computation and its update |
+| `timing/rule_reward` | Time spent scoring with rule-based rewards |
+| `timing/reward_model` | Time spent scoring with the reward model |
+| `timing/ref_onload` | Onload time during the reference model's log p computation |
+| `timing/ref_offload` | Offload time during the reference model's log p computation |
+
+* Composition of the total time in the fully co-located (integrated worker) setup
+
+`timing/all` >= `timing/rollout` + `timing/old_log_p` + `timing/update` + `timing/reference_model` + `timing/resharding_to_train` + `timing/resharding_to_infer` - `timing/overlap_reference_model` - `timing/overlap_update` + `max(timing/non_overlap_rule_reward, timing/non_overlap_reference_model)`
+
+Setting `td_timer_enabled` to `True` in `rl_config` additionally records the time spent fetching data from and storing data to the transfer dock. Since the transfer dock timing itself adds extra compute time, enable it only when needed (a usage sketch follows the table below).
+
+| Metric | Description |
+| --- | --- |
+| `timing/rollout_dispatch` | Transfer dock data-fetch time during the actor model's rollout |
+| `timing/rollout_collect` | Transfer dock data-store time during the actor model's rollout |
+| `timing/old_log_p_dispatch` | Transfer dock data-fetch time while the actor model computes log p |
+| `timing/old_log_p_collect` | Transfer dock data-store time while the actor model computes log p |
+| `timing/update_dispatch` | Transfer dock data-fetch time during the actor model's update |
+| `timing/ref_log_p_dispatch` | Transfer dock data-fetch time while the reference model computes log p |
+| `timing/ref_log_p_collect` | Transfer dock data-store time while the reference model computes log p |
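+
+The sketch below is a minimal, self-contained illustration of how the `td_timer` context manager added in `mindspeed_rl/utils/utils.py` records these metrics around a dispatch or collect call. The `InProcessTransferDock` class and the `time.sleep` placeholder are hypothetical stand-ins used only for illustration; the real code sends the `[end, start]` pair to the transfer dock actor via `ray.get(td.update_metrics.remote(...))`.
+
+```python
+# Sketch of the td_timer pattern, assuming a local stand-in replaces the Ray transfer dock actor.
+import time
+from contextlib import contextmanager
+
+
+class InProcessTransferDock:
+    """Hypothetical stand-in for the transfer dock actor; it only stores [end, start] pairs."""
+    def __init__(self):
+        self.metrics = {}
+
+    def update_metrics(self, key, value, cumulate=True):
+        self.metrics.setdefault(key, []).append(value)
+
+
+@contextmanager
+def td_timer(name, td_timer_enabled, td, cumulate=True):
+    # When disabled, the wrapped block runs with no timing overhead.
+    if not td_timer_enabled:
+        yield
+        return
+    start_time = time.time()
+    yield
+    end_time = time.time()
+    # The real implementation wraps this call with ray.get(td.update_metrics.remote(...)).
+    td.update_metrics(f"timing/{name}", value=[round(end_time, 4), round(start_time, 4)], cumulate=cumulate)
+
+
+td = InProcessTransferDock()
+with td_timer("rollout_dispatch", td_timer_enabled=True, td=td):
+    time.sleep(0.01)  # placeholder for dispatch_transfer_dock_data(...)
+print(td.metrics)     # {'timing/rollout_dispatch': [[<end>, <start>]]}
+```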
+
+**Other metrics**
+
+| Metric | Description |
+| --- | --- |
+| `actor/entropy` | Policy entropy, reflecting the randomness (exploration capability) of the policy |
+| `actor/kl_loss` | KL divergence, measuring how far the current policy deviates from the reference policy (e.g. the old policy or the reference model) |
+| `actor/pg_loss` | Policy-gradient loss, the advantage-based policy-gradient objective; reflects how well the current policy learns to increase the reward |
+| `actor/pg_clipfrac` | Fraction of samples where GRPO's clipping takes effect, reflecting the stability of policy updates |
+| `actor/ppo_kl` | Actual KL divergence of the PPO algorithm |
+| `grad_norm` | Gradient norm, the overall magnitude of parameter gradients in the current backward pass |
+| `grpo/{verifier_function}_rewards/mean` | Mean total reward from rule-based reward scoring |
+| `grpo/lr` | Learning rate currently used by the optimizer |
+| `grpo/score/mean` | Mean reward when the reward model is enabled |
+| `grpo/score/max` | Maximum reward given to a sample by the reward model and rule-based rewards |
+| `grpo/score/min` | Minimum reward given to a sample by the reward model and rule-based rewards |
+| `grpo/rewards/mean` | Mean of the rule-based rewards; for a reward model, the mean of its rewards after normalization |
+| `grpo/rewards/max` | Maximum of the rule-based rewards; for a reward model, the maximum of its rewards after normalization |
+| `grpo/rewards/min` | Minimum of the rule-based rewards; for a reward model, the minimum of its rewards after normalization |
+| `response_length/mean` | Mean generation length, the average number of tokens in generated responses |
+| `response_length/min` | Shortest response length in the current batch |
+| `response_length/max` | Longest response length in the current batch |
+| `prompt_length/mean` | Mean prompt length |
+| `prompt_length/max` | Longest prompt length in the current batch |
+| `prompt_length/min` | Shortest prompt length in the current batch |
+| `e2e_tps` | End-to-end throughput in tokens/p/s |
+| `update_tps` | Training throughput in tokens/p/s |
+| `vllm_tps` | Inference (rollout) throughput in tokens/p/s |
+
+* e2e_tps is computed as
+
+$$
+(\text{response_length_mean} + \text{prompt_length_mean}) \times \text{global_batch_size} \times \text{n_samples_per_prompt} / \text{world_size} / \text{time_all}
+$$
+
+* update_tps is computed as
+
+$$
+(\text{response_length_mean} + \text{prompt_length_mean}) \times \text{global_batch_size} \times \text{n_samples_per_prompt} / \text{world_size} / \text{time_update}
+$$
+
+* vllm_tps is computed as
+
+$$
+(\text{response_length_mean} + \text{prompt_length_mean}) \times \text{global_batch_size} \times \text{n_samples_per_prompt} / \text{world_size} / \text{time_rollout}
+$$
+
 ## Resuming Training from a Checkpoint
 When resuming training from a checkpoint, pay attention to the following parameters:
 ```yaml
@@ -112,3 +199,10 @@ rl_config:
     ref_model_load_path: ./Qwen2.5-7B-tp4 <------- When resuming, set ref_model_load_path to the original model weight path so that the reference model can load it
 ```
 
+## Reproduced Results
+
+The DeepSeekR1-ZERO training pipeline and its training results have been successfully reproduced; the detailed reproduction steps and result figures are presented in the following documents:
+
+[DeepSeekR1-ZERO-Qwen2.5-7B](../solutions/r1_zero_qwen25_7b.md)
+
+[DeepSeekR1-ZERO-Qwen2.5-32B](../solutions/r1_zero_qwen25_32b.md)
diff --git a/docs/features/integrated_worker.md b/docs/features/integrated_worker.md
index 93909ed57ea8db1b6c7f224ab37d1e9b413a8e07..141401c5a2e72fd6f2f44d4205ca2e79b12ba4a1 100644
--- a/docs/features/integrated_worker.md
+++ b/docs/features/integrated_worker.md
@@ -64,7 +64,7 @@ rl_config:
 ![sharding_process](../../sources/images/integrated_worker/sharding_process.jpg)
 
 The framework automatically enables the co-located (shared-card train/inference) Actor. In the configuration file, the sharding strategies of the training-mode and inference-mode models can be configured separately, and you can specify whether training weights, gradients, and optimizer states should be offloaded during inference.
-Taking `grpo_trainer_qwen25_7b.yaml` as an example,
+Taking `grpo_qwen25_7b_A3.yaml` as an example,
 
 ```yaml
 actor_config:
@@ -82,4 +82,3 @@ generate_config:
     offload_train_param: true        # set to true to offload training-mode weights during inference
 ```
 
-
diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py
index
e2d14038539190b6fc0ddf6306ce015f80407358..2ac7d330657c6eccb07d7f2ff0f41b104e83317e 100644 --- a/mindspeed_rl/config_cls/rl_config.py +++ b/mindspeed_rl/config_cls/rl_config.py @@ -93,6 +93,7 @@ class RLConfig(BaseConfig): self.wandb_save_dir = "" self.blocking = True self.guarantee_order = False + self.td_timer_enabled = False self.num_cpus_for_local_task = 1 self.num_cpus_for_placement_group = 8 self.use_integrated_worker = True diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index a38d0bb9d42884b7a870c62280cdcc51dab7ddca..38fc3e90859ddb90ad9e3b8f5e5cb28fe76cb7d6 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -176,11 +176,13 @@ class RayGRPOTrainer(RayBaseTrainer): metrics_result = metrics_post_processing(metrics_result) metrics_result = metrics_sort(metrics_result, all_timer.last) tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, all_timer.last) - vllm_throughput = compute_vllm_throughput(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, metrics_result["timing/rollout"]) + update_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, metrics_result["timing/update"]) + vllm_tps = compute_vllm_throughput(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, metrics_result["timing/rollout"]) metrics.update(value=metrics_result) metrics.update(value=grpo_data_metrics) - metrics.update("tokens/p/s", tps) - metrics.update("vllm_throughput", vllm_throughput) + metrics.update("e2e_tps", tps) + metrics.update("update_tps", update_tps) + metrics.update("vllm_tps", vllm_tps) iteration += 1 logger.info(metrics.metric, iteration, self.train_iters) if self.tensorboard is not None: diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py index 44d5a631aab266aee2432555c2dab056ff456074..7f422469efe4a638d0af9128a907874e16aed35d 100644 --- a/mindspeed_rl/utils/utils.py +++ b/mindspeed_rl/utils/utils.py @@ -4,6 +4,7 @@ import os import sys import json +from contextlib import contextmanager import time import math @@ -11,6 +12,7 @@ import random from functools import wraps from typing import Dict, List +import ray import omegaconf import numpy as np import torch @@ -171,6 +173,21 @@ def get_batch_metrices_mean(metrics_list: List[Dict]) -> Dict[str, Tensor]: return metrics_mean +@contextmanager +def td_timer(name, td_timer_enabled, td, cumulate=True): + if td_timer_enabled: + start_time = time.time() + yield + end_time = time.time() + ray.get(td.update_metrics.remote( + f"timing/{name}", + value=[round(end_time, 4), round(start_time, 4)], + cumulate=cumulate + )) + else: + yield + + def metrics_post_processing(metrics) -> Dict[str, Tensor]: """ Calculate the mean of each metric across a list of metric dictionaries. 
@@ -207,19 +224,29 @@ def metrics_post_processing(metrics) -> Dict[str, Tensor]: def metrics_sort(metrics, time_all) -> Dict[str, Tensor]: - custom_order = ['timing/all', 'timing/update', 'timing/resharding_to_train', 'timing/rollout', 'timing/resharding_to_infer', 'timing/old_log_p', 'timing/reference_model', 'timing/non_overlap_reference_model'] - special_keys = ['timing/non_overlap_rule_reward', 'timing/non_overlap_reward_model', 'timing/rule_reward', 'timing/reward_model', 'timing/adv', 'timing/non_overlap_adv', 'timing/onload', 'timing/offload'] - old_log_p_end_time = metrics.pop('end_time/old_log_p', None) - end_adv_time = metrics.pop('end_time/end_adv_time', None) - + custom_order = ['timing/all', 'timing/update', 'timing/rollout', 'timing/old_log_p', 'timing/reference_model', 'timing/resharding_to_infer', 'timing/resharding_to_train', 'timing/adv', 'timing/non_overlap_reference_model'] + special_keys = ['timing/non_overlap_rule_reward', 'timing/non_overlap_reward_model', 'timing/non_overlap_adv', 'timing/overlap_reference_model', 'timing/overlap_update', 'timing/rule_reward', 'timing/reward_model', \ + 'timing/rollout_dispatch', 'timing/rollout_collect', 'timing/old_log_p_dispatch', 'timing/old_log_p_collect', 'timing/update_dispatch', 'timing/ref_log_p_dispatch', 'timing/ref_log_p_collect', \ + 'timing/ref_onload', 'timing/ref_offload'] + reference_start_time = metrics.pop('start_time/reference_model', None) reference_end_time = metrics.pop('end_time/reference', None) + old_log_p_start_time = metrics.pop('start_time/old_log_p', None) + old_log_p_end_time = metrics.pop('end_time/old_log_p', None) + end_adv_time = metrics.pop('end_time/end_adv_time', None) + update_start_time = metrics.pop('start_time/update', None) if old_log_p_end_time is None: old_log_p_end_time = reference_end_time custom_order.remove('timing/old_log_p') non_overlap_reference_model_time = max(reference_end_time - max(old_log_p_end_time, reference_start_time), 0) non_overlap_adv_time = max(max(old_log_p_end_time, end_adv_time) - old_log_p_end_time, 0) + + if old_log_p_start_time is not None and reference_end_time > old_log_p_start_time: + metrics["timing/overlap_reference_model"] = reference_end_time - old_log_p_start_time + + if old_log_p_end_time > update_start_time: + metrics["timing/overlap_update"] = old_log_p_end_time - update_start_time if "timing/rule_reward" in metrics.keys(): reward_start_time = metrics.pop('start_time/rule_reward', None) @@ -250,7 +277,6 @@ def metrics_sort(metrics, time_all) -> Dict[str, Tensor]: def compute_tps(compute_kwargs, metrics_result, gbs, n_samples, time_all): - actor_resource = compute_kwargs.get('actor_resource', {}) reference_resource = compute_kwargs.get('reference_resource', {}) reward_resource = compute_kwargs.get('reward_resource', None) @@ -276,7 +302,7 @@ def compute_vllm_throughput(compute_kwargs, metrics_result, gbs, n_samples, time reward_npus = reward_resource.get('num_npus', 0) if reward_resource is not None else 0 world_size = actor_npus + reference_npus + reward_npus if not actor_resource_only else actor_npus - vllm_throughput = metrics_result['response_length/mean'] * gbs * n_samples / world_size / time_rollout + vllm_throughput = (metrics_result['response_length/mean'] + metrics_result['prompt_length/mean']) * gbs * n_samples / world_size / time_rollout return vllm_throughput diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index 
0a81054a1e28ac15c84fca6363fc6ded3ef7b654..72f4e0c743ef452d230c8f3391e702d7b15d6a5f 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -21,7 +21,7 @@ from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.utils.utils import MsProbe from mindspeed_rl.workers.base_worker import BaseWorker from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronShardingManager, MegatronOffLoader -from mindspeed_rl.utils.utils import num_floating_point_operations, get_attr_wrapped_model, mstx_timer_decorator, profiler_start, profiler_step +from mindspeed_rl.utils.utils import num_floating_point_operations, get_attr_wrapped_model, mstx_timer_decorator, profiler_start, profiler_step, td_timer from mindspeed_rl.utils.pad_process import remove_padding_and_split_to_list, truncate_rows @@ -165,16 +165,24 @@ class ActorHybridWorkerBase(BaseWorker): start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0: - batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, - experience_columns, - experience_count, - self.megatron_config.tensor_model_parallel_size, - indexes=sorted_indexes.pop( - 0) if self.rl_config.guarantee_order else None, - get_n_samples=False) + with td_timer("update_dispatch", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, + experience_columns, + experience_count, + self.megatron_config.tensor_model_parallel_size, + indexes=sorted_indexes.pop( + 0) if self.rl_config.guarantee_order else None, + get_n_samples=False) if not start_time_defined: start_time = time.time() start_time_defined = True + ray.get( + self.td.update_metrics.remote( + "start_time/update", + value=[round(start_time, 4)], + cumulate=True + ) + ) if batch_data and index: metrics = self.actor_hybrid.update_actor(batch_data, kl_ctrl) @@ -235,14 +243,15 @@ class ActorHybridWorkerBase(BaseWorker): start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True) > 0: - batch_data, index = self.dispatch_transfer_dock_data( - experience_consumer_stage, - experience_columns, - experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None, - use_vllm=True - ) + with td_timer("rollout_dispatch", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + batch_data, index = self.dispatch_transfer_dock_data( + experience_consumer_stage, + experience_columns, + experience_count, + tp_size=self.megatron_config.tensor_model_parallel_size, + indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None, + use_vllm=True + ) if not start_time_defined: start_time = time.time() start_time_defined = True @@ -274,7 +283,8 @@ class ActorHybridWorkerBase(BaseWorker): 'input_ids': input_ids_list, 'response_length': responses_length } - self.collect_transfer_dock_data(outputs, index, use_vllm=True) + with td_timer("rollout_collect", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + self.collect_transfer_dock_data(outputs, index, use_vllm=True) end_time = time.time() MsProbe.save_data({"responses": responses, "prompts": prompts}) @@ -303,7 +313,6 @@ class ActorHybridWorkerBase(BaseWorker): @mstx_timer_decorator def compute_log_prob(self): self.sharding_manager.enter_forward_mode() - experience_consumer_stage = 'actor_log_prob' experience_columns = 
['input_ids', 'responses', 'response_length', 'prompt_length'] experience_count = self.rl_config.actor_logprob_dispatch_size @@ -315,15 +324,23 @@ class ActorHybridWorkerBase(BaseWorker): MsProbe.debugger_start(self.model[0], tag='actor_compute_log_prob') start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0: - batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, - experience_columns, - experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - indexes=sorted_indexes.pop( - 0) if self.rl_config.guarantee_order else None, - get_n_samples=False) + with td_timer("old_log_p_dispatch", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, + experience_columns, + experience_count, + tp_size=self.megatron_config.tensor_model_parallel_size, + indexes=sorted_indexes.pop( + 0) if self.rl_config.guarantee_order else None, + get_n_samples=False) if not start_time_defined: start_time = time.time() + ray.get( + self.td.update_metrics.remote( + "start_time/old_log_p", + value=[round(start_time, 4)], + cumulate=True + ) + ) start_time_defined = True if batch_data and index: output, batch = self.actor_hybrid.compute_log_prob(batch_data) @@ -333,7 +350,8 @@ class ActorHybridWorkerBase(BaseWorker): log_probs = log_probs.to(torch.float32) log_probs = truncate_rows(log_probs, batch['response_length']) output = {'old_log_prob': log_probs} - self.collect_transfer_dock_data(output, index) + with td_timer("old_log_p_collect", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + self.collect_transfer_dock_data(output, index) end_time = time.time() ray.get( self.td.update_metrics.remote( diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py index 5a4de8f8137a438cee91b8006ea9ca6f7f86af8b..e276aacb50a48a2f76bd39ef403af34b3e91e90f 100644 --- a/mindspeed_rl/workers/integrated_worker.py +++ b/mindspeed_rl/workers/integrated_worker.py @@ -14,7 +14,7 @@ from mindspeed_rl.config_cls.generate_config import GenerateConfig from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronOffLoader -from mindspeed_rl.utils.utils import mstx_timer_decorator, profiler_start, profiler_step +from mindspeed_rl.utils.utils import mstx_timer_decorator, profiler_start, profiler_step, td_timer from mindspeed_rl.utils.utils import MsProbe from mindspeed_rl.workers.actor_hybrid_worker import ActorHybridWorkerBase @@ -107,16 +107,9 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB @mstx_timer_decorator def compute_ref_log_prob(self): - start_onload_time = time.time() - self.ref_manager.onload_param() - end_onload_time = time.time() - ray.get( - self.td.update_metrics.remote( - "timing/onload", - value=[round(end_onload_time, 4), round(start_onload_time, 4)], - cumulate=True - ) - ) + with td_timer("ref_onload", td_timer_enabled=True, td=self.td, cumulate=True): + self.ref_manager.onload_param() + compute_log_prob_profiler = profiler_start(self.profiler_config, role="reference_compute_log_prob", profiler_iteration=self.prof_iteration) MsProbe.debugger_start(model=self.ref_model, tag="reference_compute_log_prob") @@ -131,16 +124,8 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB 
ReferenceWorkerBase.compute_ref_log_prob(self) profiler_step(compute_log_prob_profiler) MsProbe.debugger_stop("reference_compute_log_prob") - start_offload_time = time.time() - self.ref_manager.offload_param() - end_offload_time = time.time() - ray.get( - self.td.update_metrics.remote( - "timing/offload", - value=[round(end_offload_time, 4), round(start_offload_time, 4)], - cumulate=True - ) - ) + with td_timer("ref_offload", td_timer_enabled=True, td=self.td, cumulate=True): + self.ref_manager.offload_param() def compute_log_prob(self): if self.actor_forward_micro_batch_size is not None: diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 942582bb34aab712d7e8be5e2e67206b03fbce74..299c9ea607c7dd24abe15a000f701ca61e66c236 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -16,7 +16,7 @@ from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.workers.base_worker import BaseWorker from mindspeed_rl.utils.compute import get_parallel_state from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank -from mindspeed_rl.utils.utils import mstx_timer_decorator +from mindspeed_rl.utils.utils import mstx_timer_decorator, td_timer class ReferenceWorkerBase(BaseWorker): @@ -94,13 +94,14 @@ class ReferenceWorkerBase(BaseWorker): start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0: - batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, - experience_columns, - experience_count, - tp_size=self.megatron_config.tensor_model_parallel_size, - indexes=sorted_indexes.pop( - 0) if self.rl_config.guarantee_order else None, - get_n_samples=False) + with td_timer("ref_log_p_dispatch", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage, + experience_columns, + experience_count, + tp_size=self.megatron_config.tensor_model_parallel_size, + indexes=sorted_indexes.pop( + 0) if self.rl_config.guarantee_order else None, + get_n_samples=False) if not start_time_defined: start_time = time.time() @@ -121,7 +122,8 @@ class ReferenceWorkerBase(BaseWorker): log_probs = log_probs.to(torch.float32) log_probs = truncate_rows(log_probs, batch['response_length']) output = {'ref_log_prob': log_probs} - self.collect_transfer_dock_data(output, index) + with td_timer("ref_log_p_collect", self.rl_config.td_timer_enabled, td=self.td, cumulate=True): + self.collect_transfer_dock_data(output, index) end_time = time.time() ray.get( self.td.update_metrics.remote(