From 209ef0b4ee790b43813ffc6f61e907c39e6d8b16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9F=A9=E6=8C=AF=E5=AE=87?= Date: Thu, 19 Jun 2025 16:40:01 +0800 Subject: [PATCH] support for distributed TransferDock --- mindspeed_rl/trainer/grpo_trainer_hybrid.py | 170 ++- mindspeed_rl/trainer/utils/__init__.py | 5 +- mindspeed_rl/trainer/utils/compute_utils.py | 28 +- mindspeed_rl/trainer/utils/transfer_dock.py | 1086 +++++++++---------- mindspeed_rl/utils/metrics.py | 40 + mindspeed_rl/workers/actor_hybrid_worker.py | 49 +- mindspeed_rl/workers/base_worker.py | 189 +--- mindspeed_rl/workers/integrated_worker.py | 4 +- mindspeed_rl/workers/reference_woker.py | 20 +- mindspeed_rl/workers/reward_woker.py | 9 +- mindspeed_rl/workers/rule_reward.py | 20 +- mindspeed_rl/workers/scheduler/launcher.py | 15 +- 12 files changed, 791 insertions(+), 844 deletions(-) diff --git a/mindspeed_rl/trainer/grpo_trainer_hybrid.py b/mindspeed_rl/trainer/grpo_trainer_hybrid.py index ce82264d..15dcedad 100644 --- a/mindspeed_rl/trainer/grpo_trainer_hybrid.py +++ b/mindspeed_rl/trainer/grpo_trainer_hybrid.py @@ -12,12 +12,12 @@ from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.workers.rule_reward import RuleReward from mindspeed_rl.trainer.base import RayBaseTrainer from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig -from mindspeed_rl.trainer.utils import GRPOTransferDock, MMGRPOTransferDock +from mindspeed_rl.trainer.utils.transfer_dock import put_prompts, TransferDockManager, DistributedStorageNode from mindspeed_rl.trainer.utils.compute_utils import compute_advantage, compute_grpo_data_metrics from mindspeed_rl.workers.scheduler.launcher import RayActorGroup from mindspeed_rl.utils.loggers import Loggers -from mindspeed_rl.utils.metrics import Metric -from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort, is_multimodal +from mindspeed_rl.utils.metrics import Metric, ActorMetric +from mindspeed_rl.utils.utils import metrics_post_processing, compute_tps, metrics_sort class RayGRPOTrainer(RayBaseTrainer): @@ -63,7 +63,7 @@ class RayGRPOTrainer(RayBaseTrainer): dataset_additional_keys: List[str] = None, blocking: bool = False, guarantee_order: bool = False, - num_cpus_for_local_task: int = 1, + num_cpus_for_local_task: float = 0.1, **kwargs ): super().__init__( @@ -88,26 +88,122 @@ class RayGRPOTrainer(RayBaseTrainer): **kwargs ) - self.transfer_dock = None - self.mm_transfer_dock = None - self.metrics = Metric() + self.consumers_columns = { + 'actor_rollout': ['prompts', 'prompt_length'], + 'actor_log_prob': ['input_ids', 'responses', 'response_length', 'prompt_length'], + 'ref_log_prob': ['input_ids', 'responses', 'response_length', 'prompt_length'], + 'actor_train': ['responses', 'advantages', 'old_log_prob', 'ref_log_prob', + 'input_ids', 'response_length', 'prompt_length'], + 'compute_advantage': ["responses", "token_level_rewards", "response_length"], + 'rule_reward': ['prompts', 'responses', 'response_length'], + 'reward_scores': ['input_ids', 'prompt_length', "responses", "response_length"], + 'grpo_metrics': ["responses", "token_level_rewards", "response_length", + "rm_scores", "advantages", "returns", "prompt_length"], + } + if self.dataset_additional_keys: + self.consumers_columns['rule_reward'] += self.dataset_additional_keys + self.consumers_columns['reward_scores'] += self.dataset_additional_keys + + self.TD_managers_map_consumers = { + 'ActorHybridWorker': ['actor_rollout', 'actor_log_prob', 'actor_train'], + 'ActorWorker': 
['actor_log_prob', 'actor_train'],
+            'ActorFwdWorker': ['actor_log_prob'],
+            'RolloutWorker': ['actor_rollout'],
+            'ReferenceWorker': ['ref_log_prob'],
+            'RewardWorker': ['reward_scores'],
+            'RuleReward': ['rule_reward'],
+            'ComputeAdvantage': ['compute_advantage'],
+            'Metrics': ['grpo_metrics'],
+        }
+
+        # ActorMetric is simply Metric exposed as a Ray actor, shared by all workers
+        self.metrics = ActorMetric.remote()
         self.transfer_dock_init()
+        self.init_metric()
+        self.kwargs = kwargs
         self.set_actor_log_prob_skip_flag()
 
+    def init_metric(self):
+        self.actor_worker.sync_init_metrics(self.metrics)
+        self.ref_worker.sync_init_metrics(self.metrics)
+        for reward in self.reward_list:
+            if isinstance(reward, RayActorGroup):
+                reward.sync_init_metrics(self.metrics)
+            else:
+                reward.init_metrics.remote(self.metrics)
+
     def transfer_dock_init(self):
-        self.transfer_dock = GRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt,
-                                                     self.metrics, addition_columns=self.dataset_additional_keys)
-        if is_multimodal():
-            self.mm_transfer_dock = MMGRPOTransferDock.remote(self.global_batch_size, self.n_samples_per_prompt)
+        # Initialize the distributed data (storage) nodes
+        all_columns = set()
+        for column in self.consumers_columns.values():
+            all_columns.update(column)
+
+        num_data_nodes = self.actor_worker.get_distributed_node_num()
+
+        distribute_data_nodes = []
+        for data_node_rank in range(num_data_nodes):
+            storage_node = DistributedStorageNode(self.global_batch_size * self.n_samples_per_prompt,
+                                                  data_node_rank,
+                                                  num_data_nodes,
+                                                  list(all_columns))
+            distribute_data_nodes.append(storage_node)
+        self.distribute_data_nodes = distribute_data_nodes
+
+        # Initialize the TD managers
+        self.td_managers_all = []
+        if self.kwargs.get('distribute_td_managers', False):
+            self._transfer_dock_init_distribute()
+        else:
+            self._transfer_dock_init_single()
+
+        # Register the TD managers with every data node
+        for node in self.distribute_data_nodes:
+            node.handler.register_td_managers.remote(self.td_managers_all)
+
+    def _transfer_dock_init_single(self):
+        consumers_columns_all = list(self.consumers_columns.keys())
+        transfer_dock = self.instance_td_manager(consumers_columns_all)
+        self.metrics_td = transfer_dock
+        self.compute_advantage_td = transfer_dock
+        self.td_managers_all.append(transfer_dock)
 
-        self.actor_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
-        self.ref_worker.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
+        self.actor_worker.sync_init_transfer_dock(transfer_dock)
+        self.ref_worker.sync_init_transfer_dock(transfer_dock)
         for reward in self.reward_list:
-            if hasattr(reward, 'sync_init_transfer_dock'):
-                reward.sync_init_transfer_dock(self.transfer_dock, self.mm_transfer_dock)
+            if isinstance(reward, RayActorGroup):
+                reward.sync_init_transfer_dock(transfer_dock)
             else:
-                reward.init_transfer_dock.remote(self.transfer_dock, self.mm_transfer_dock)
+                reward.init_transfer_dock.remote(transfer_dock)
+
+    def _transfer_dock_init_distribute(self):
+        # Initialize the actor TD
+        actor_worker_td = self.instance_td_manager(self.TD_managers_map_consumers['ActorHybridWorker'])
+        self.td_managers_all.append(actor_worker_td)
+        self.actor_worker.sync_init_transfer_dock(actor_worker_td)
+
+        # Initialize the reference TD
+        reference_worker_td = self.instance_td_manager(self.TD_managers_map_consumers['ReferenceWorker'])
+        self.td_managers_all.append(reference_worker_td)
+        self.ref_worker.sync_init_transfer_dock(reference_worker_td)
+
+        if self.kwargs.get('reward_resource', {}):
+            reward_worker_td = self.instance_td_manager(self.TD_managers_map_consumers['RewardWorker'])
+
self.td_managers_all.append(reward_worker_td) + + if self.kwargs.get('rule_reward', None): + rule_reward_td = self.instance_td_manager(self.TD_managers_map_consumers['RuleReward']) + self.td_managers_all.append(rule_reward_td) + + for reward in self.reward_list: + if isinstance(reward, RayActorGroup): + reward.sync_init_transfer_dock(reward_worker_td) + else: + reward.init_transfer_dock.remote(rule_reward_td) + + self.compute_advantage_td = self.instance_td_manager(self.TD_managers_map_consumers['ComputeAdvantage']) + self.metrics_td = self.instance_td_manager(self.TD_managers_map_consumers['Metrics']) + self.td_managers_all += [self.compute_advantage_td, self.metrics_td] def set_actor_log_prob_skip_flag(self): global_batch_size = self.actor_worker.megatron_config.global_batch_size @@ -132,13 +228,10 @@ class RayGRPOTrainer(RayBaseTrainer): logger.info('async start grpo training at iteration: {}/{} ...'.format(iteration, self.train_iters)) while iteration < self.train_iters: - ray.get(self.transfer_dock.clear.remote()) + self.clear_td_managers() batch = next(data_iters) - ray.get(self.transfer_dock.put_prompts_experience.remote(batch, self.dataset_additional_keys)) - if is_multimodal(): - ray.get(self.mm_transfer_dock.clear.remote()) - ray.get(self.mm_transfer_dock.put_experience.remote(batch, indexes=[i for i in range(len(batch['prompts']) * self.n_samples_per_prompt)])) + put_prompts(self.metrics_td, batch, self.n_samples_per_prompt, self.dataset_additional_keys) with Timer(name='iteration', logger=None) as all_timer: # generate sequences @@ -156,14 +249,14 @@ class RayGRPOTrainer(RayBaseTrainer): self.compute_advantage(blocking=False, guarantee_order=self.guarantee_order) # compute reference log_prob - self.ref_worker.compute_ref_log_prob(blocking=self.blocking) + self.ref_worker.compute_ref_log_prob(blocking=self.blocking) # 分布式TD移到上面去了,此处和原代码保持一致 # compute old log_prob if not self.skip_actor_log_prob: self.actor_worker.compute_log_prob(blocking=self.blocking) - self.actor_worker.wait_all_ref_objs_run_over() - + self.actor_worker.wait_all_ref_objs_run_over() # 分布式TD移到上面去了,此处和原代码保持一致 + self.ref_worker.wait_all_ref_objs_run_over() for reward in self.reward_list: if hasattr(reward, 'wait_all_ref_objs_run_over'): @@ -173,15 +266,13 @@ class RayGRPOTrainer(RayBaseTrainer): self.actor_worker.update(self.kl_ctrl, self.skip_actor_log_prob) # collect metrics - grpo_data_metrics = compute_grpo_data_metrics(self.transfer_dock, + grpo_data_metrics = compute_grpo_data_metrics(self.metrics_td, self.global_batch_size * self.n_samples_per_prompt, self.tokenizer, self.global_batch_size * self.n_samples_per_prompt, self.guarantee_order) - metrics_result = ray.get(self.transfer_dock.get_metrics.remote()) - metrics_result = metrics_post_processing(metrics_result) - metrics_result = metrics_sort(metrics_result, all_timer.last) + metrics_result = ray.get(self.metrics.metrics_post_processing.remote()) tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, all_timer.last) update_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, metrics_result["timing/update"]) vllm_tps = compute_tps(self.kwargs, grpo_data_metrics, self.global_batch_size, self.n_samples_per_prompt, metrics_result["timing/rollout"]) @@ -208,7 +299,7 @@ class RayGRPOTrainer(RayBaseTrainer): start_adv_time = time.time() compute_advantage_ref = compute_advantage.options(num_cpus=self.num_cpus_for_local_task).remote( - self.transfer_dock, + 
self.compute_advantage_td, self.gamma, self.lam, adv_estimator=self.adv_estimator, @@ -221,14 +312,14 @@ class RayGRPOTrainer(RayBaseTrainer): ray.get(compute_advantage_ref) end_adv_time = time.time() ray.get( - self.transfer_dock.update_metrics.remote( - "timing/adv", + self.metrics.update.remote( + "timing/adv", value=[round(end_adv_time, 4), round(start_adv_time, 4)], cumulate=True ) - ) + ) ray.get( - self.transfer_dock.update_metrics.remote( + self.metrics.update.remote( "end_time/end_adv_time", value=[round(end_adv_time, 4)], cumulate=True @@ -237,3 +328,16 @@ class RayGRPOTrainer(RayBaseTrainer): def save_checkpoint(self, iteration: int): self.actor_worker.save_checkpoint(iteration) + + def instance_td_manager(self, consumers_column): + transfer_dock_manager = TransferDockManager.remote(consumers_list=consumers_column, + consumers_columns=self.consumers_columns, + prompts_num=self.global_batch_size, + n_samples_per_prompt=self.n_samples_per_prompt, + distribute_data_nodes=self.distribute_data_nodes) + return transfer_dock_manager + + def clear_td_managers(self): + ray.get(self.metrics.reset.remote()) + for manager in self.td_managers_all: + ray.get(manager.clear.remote()) \ No newline at end of file diff --git a/mindspeed_rl/trainer/utils/__init__.py b/mindspeed_rl/trainer/utils/__init__.py index 07a2455a..c212f9ce 100644 --- a/mindspeed_rl/trainer/utils/__init__.py +++ b/mindspeed_rl/trainer/utils/__init__.py @@ -9,8 +9,7 @@ from .compute_utils import ( ) from .training import get_finetune_data_on_this_tp_rank, broadcast_data -from .transfer_dock import TransferDock, GRPOTransferDock -from .mm_transfer_dock import MMGRPOTransferDock +from .transfer_dock import TransferDock __all__ = [ "AdaptiveKLController", @@ -21,6 +20,4 @@ __all__ = [ "get_last_reward", "compute_grpo_data_metrics", 'TransferDock', - 'GRPOTransferDock', - 'MMGRPOTransferDock' ] diff --git a/mindspeed_rl/trainer/utils/compute_utils.py b/mindspeed_rl/trainer/utils/compute_utils.py index 8aaa110b..f2066211 100644 --- a/mindspeed_rl/trainer/utils/compute_utils.py +++ b/mindspeed_rl/trainer/utils/compute_utils.py @@ -24,7 +24,7 @@ import numpy as np import mindspeed_rl.utils.torch_functional as F from mindspeed_rl.utils.pad_process import truncate_rows from mindspeed_rl.utils.utils import generate_mask, get_current_dp_range_indexes -from mindspeed_rl.trainer.utils.transfer_dock import pad_experience +from mindspeed_rl.trainer.utils.transfer_dock import pad_experience, get_experience, put_experience from mindspeed_rl.utils.utils import mstx_timer_decorator @@ -145,17 +145,12 @@ def compute_advantage(td, gamma, lam, adv_estimator, experience_count, tokenizer None """ experience_consumer_stage = "compute_advantage" - experience_columns = ["responses", "token_level_rewards", "response_length"] pad_token_id = tokenizer.pad if tokenizer.pad is not None else tokenizer.eod sorted_indexes = get_current_dp_range_indexes(experience_count=experience_count, assign_batch_size=global_batch_size) if guarantee_order else None while not ray.get(td.all_consumed.remote(experience_consumer_stage)): - batch_data, index = ray.get( - td.get_experience.remote( - experience_consumer_stage, experience_columns, experience_count, # pad_id=pad_token_id - indexes=sorted_indexes.pop(0) if guarantee_order else None - ) - ) + batch_data, index = get_experience(td, experience_consumer_stage, experience_count, + indexes=sorted_indexes.pop(0) if guarantee_order else None) if batch_data and index: batch_data = pad_experience(batch_data, pad_token_id) # 
multiple, tp_size response_mask = generate_mask(batch_data["responses"], batch_data["response_length"]) @@ -178,7 +173,7 @@ def compute_advantage(td, gamma, lam, adv_estimator, experience_count, tokenizer "advantages": advantages, "returns": returns, } - td.put_experience.remote(data_dict=output, indexes=index) + put_experience(td, data_dict=output, indexes=index) def get_last_reward(rm_scores, n_sample_batch: int): @@ -215,23 +210,12 @@ def compute_grpo_data_metrics( Dictionary containing various metric values """ experience_consumer_stage = "grpo_metrics" - experience_columns = [ - "rm_scores", - "token_level_rewards", - "responses", - "advantages", - "returns", - "prompt_length", - "response_length", - ] pad_token_id = tokenizer.pad if tokenizer.pad is not None else tokenizer.eod sorted_indexes = get_current_dp_range_indexes(experience_count=experience_count, assign_batch_size=global_batch_size) if guarantee_order else None while not ray.get(td.all_consumed.remote(experience_consumer_stage)): - batch, index = ray.get( - td.get_experience.remote(experience_consumer_stage, experience_columns, experience_count, - indexes=sorted_indexes.pop(0) if guarantee_order else None) - ) + batch, index = get_experience(td, experience_consumer_stage, experience_count, + indexes=sorted_indexes.pop(0) if guarantee_order else None) if batch and index: batch = pad_experience(batch, pad_token_id) # multiple, tp_size sequence_score = batch["rm_scores"].sum(-1) diff --git a/mindspeed_rl/trainer/utils/transfer_dock.py b/mindspeed_rl/trainer/utils/transfer_dock.py index 92dad94c..7839b4f8 100644 --- a/mindspeed_rl/trainer/utils/transfer_dock.py +++ b/mindspeed_rl/trainer/utils/transfer_dock.py @@ -4,8 +4,10 @@ import copy import time import threading from abc import ABC -from typing import List, Dict, Union, Optional +from typing import List, Dict, Union, Tuple from operator import itemgetter +import itertools +import copy import ray import torch @@ -24,55 +26,158 @@ class TimeoutException(Exception): super().__init__(message) -class TransferDock(ABC): +class DistributedMetadata: """ - TransferDock is a data structure class that serves as the base class for GRPOTransferDock, - providing data storage and retrieval functions. + get distributed data node information + + elements: + data_node_rank : distributed data node rank + handler : data node ray handler + consumer_columns : consumer columns + indexes : global indexes + experience_offset : experience offset """ + data_node_rank: int + handler: ray.actor.ActorHandle + consumer_columns: List[str] + indexes: List[int] + experience_offset: List[int] - def __init__( - self, - prompts_num: int, - n_samples_per_prompt: int, - experience_columns: Union[List[str], None], - timeout: Union[int, None], - timeout_interval: Union[int, None], - ) -> None: - """TransferDock initialize. - - Args: - prompts_num: The number of prompts loaded from the dataset. - n_samples_per_prompt: The number of responses for a prompt. - experience_columns: Data columns in TransferDock. 
-            timeout: The waiting time for over time printing
-            timeout_interval: Time interval for timeout printing
-        """
-        super().__init__()
+class TransferDockConsumer(ABC):
+    def __init__(self, consumer: str, consumer_columns: List[str], prompts_num: int, n_samples_per_prompt: int) -> None:
         self.prompts_num = prompts_num
         self.n_samples_per_prompt = n_samples_per_prompt
         self.max_len = prompts_num * n_samples_per_prompt
+        self.consumer_id = consumer  # consumer ID
+        self.consumers_columns = consumer_columns  # experience columns this consumer needs
+
+        '''
+        Per-consumer data status table.
+        For example, the 'actor_logp' consumer needs ['prompts', 'responses'],
+        so consumers_data_status is a self.max_len x 3 matrix:
+        index 'prompts', 'responses', 'status'; the last column is the data status,
+        where 0 = not ready, 1 = ready, 2 = consumed
+        0     1          0            0
+        1     1          1            1
+        2     1          1            2
+        '''
+        self.consumers_data_status = torch.zeros(self.max_len, len(consumer_columns) + 1, dtype=torch.int32)
+
+    def get_ready_index(self, experience_count: int):
+        # Same selection scheme as the previous sampler: 1) sequential scan, break once enough
+        # rows are found; 2) random sampling; 3) conditional selection
+
+        # 1) sequential scan
+        # update the data status (only a single process runs this)
+        experience_index = []
+        for index, index_data in enumerate(self.consumers_data_status):
+            if index_data[-1] == 1:
+                experience_index.append(index)
+            elif index_data[-1] == 0 and index_data[:-1].sum().item() == self.consumers_data_status.shape[1] - 1:
+                index_data[-1] = 1  # mark the row as ready to consume
+                experience_index.append(index)
+            if len(experience_index) == experience_count:
+                break
+
+        if len(experience_index) < experience_count:
+            return None
+
+        self.consumers_data_status[experience_index, -1] = 2  # mark the selected rows as consumed
+        return experience_index
+
+    def get_ready_index_n_samples(self, experience_count: int):
+        hit_lines = []
+        match_lines_count = experience_count // self.n_samples_per_prompt
+        for line in range(0, len(self.consumers_data_status), self.n_samples_per_prompt):
+            if self.consumers_data_status[line][-1] == 2:
+                continue
+            if self.consumers_data_status[line][-1] == 1:
+                hit_lines.append(line)
+            else:
+                current_pair = self.consumers_data_status[line: line + self.n_samples_per_prompt, : -1]
+                if torch.all(current_pair == 1):
+                    self.consumers_data_status[line: line + self.n_samples_per_prompt, -1] = 1
+                    hit_lines.append(line)
+            if len(hit_lines) == match_lines_count:
+                break
+
+        if len(hit_lines) < match_lines_count:
+            return None
+
+        experience_index = []
+        for line in hit_lines:
+            self.consumers_data_status[line: line + self.n_samples_per_prompt, -1] = 2
+            experience_index.extend(list(range(line, line + self.n_samples_per_prompt)))
+        return experience_index
+
+    def update_data_write_status(self, put_experience_columns: List[str], put_experience_index: List[int]):
+        # update this consumer's data status
+        for i, column in enumerate(self.consumers_columns):
+            if column in put_experience_columns:
+                for index in put_experience_index:
+                    # concurrent writes are safe here: each write touches a distinct position
+                    self.consumers_data_status[index][i] = 1
+
+    def ready_for_consume(self, indexes: List[int]) -> bool:
+        """
+        Check that all indexes are ready.
+        """
+        for index in indexes:
+            index_data = self.consumers_data_status[index]
+            if index_data[-1] > 0:
+                continue
 
-        self.experience_columns = (
-            experience_columns if experience_columns is not None else []
-        )
+
+            # do not update the status while in the reading stage
+            if index_data[:-1].sum().item() < self.consumers_data_status.shape[1] - 1:
+                return False
+
+        return True
+
+    def update_consumers_data_status(self, indexes: List[int], column_idx: int, status: int):
+        self.consumers_data_status[[indexes], column_idx] = status
+
+    def all_consumed(self):
+        """Check whether this consumer has consumed all data.
+
+        Returns:
+            True if the last (status) column is all 2, else False
+        """
+
return self.consumers_data_status[:, -1].sum() == 2 * self.max_len + + def clear(self): + self.consumers_data_status = torch.zeros(self.max_len, len(self.consumers_columns) + 1, dtype=torch.int32) + + +@ray.remote(max_concurrency=16, num_cpus=16) +class TransferDockData: + def __init__( + self, + max_len: int, + data_node_rank: int, + num_distribute_nodes: int, + experience_columns: Union[List[str], None] + ): + self.max_len = max_len + self.data_node_rank = data_node_rank + self.num_distribute_nodes = num_distribute_nodes + self.experience_columns = experience_columns \ + if experience_columns is not None else [] self.experience_data = { - key: [None for _ in range(self.max_len)] - for key in self.experience_columns - } - self.experience_data_status = { - key: torch.zeros(self.max_len, dtype=torch.int32) - for key in self.experience_columns + key: [None for _ in range(self.max_len)] for key in self.experience_columns } - self.timeout = timeout if timeout is not None else 300 - self.timeout_interval = timeout_interval if timeout_interval is not None else 5 + self.all_td_managers = None + + def _convert_index_global_to_local(self, indexes: List[int]): + return [index % (self.max_len // self.num_distribute_nodes) + for index in indexes] - def _put( + def put_data( self, experience_columns: List[str], - experience: List[List[List[torch.Tensor]]], - indexes: List[int] = None, + experience: List[List[Union[Tensor, List[Tensor]]]], + experience_offset: List[int], + indexes: List[int], ): """Put data into specified columns and rows. @@ -94,34 +199,53 @@ class TransferDock(ABC): tensor([4, 4, 4, 4]) ] ] + experience_offset: experience offset indexes: Rows to put data in. [0, 1, 2, 4] + manager_handler: ray manager handler Returns: None """ + # If experience_columns not in TD, raise ValueError for experience_column in experience_columns: if experience_column not in self.experience_columns: - raise ValueError( - f"put experience ERROR: {experience_column} not in TD experience_column {self.experience_columns}" - ) + raise ValueError(f"Put data ERROR: {experience_column} not in TD") - if not indexes: - raise ValueError( - "put experience into TD without indexes, indexes must be provided" - ) + if indexes is None: + raise ValueError(f"put experience ERROR: indexes is None") - if max(indexes) >= self.max_len: - raise ValueError( - f"Put experience index {max(indexes)} exceeds the Transfer Dock range {self.max_len}." - ) + self._put_with_index(experience_columns, experience, + self._convert_index_global_to_local(indexes), + experience_offset) + for manager_handler in self.all_td_managers: + # ray.get(manager_handler.report_data_produce_status.remote(experience_columns, indexes)) + ray.get(manager_handler.report_data_produce_status.remote(experience_columns=experience_columns, indexes=indexes)) + def _put_with_index( + self, + experience_columns: List[str], + experience: List[List[List[torch.Tensor]]], + local_indexes: List[int], + experience_offset: List[int] + ): + """Put data into specified columns and rows. + + Args: + experience_columns: Columns to put data in. + experience: Data for the corresponding columns. 
+ local_indexes: Local index to put data in + experience_offset: experience offset + + Returns: None + + """ for column_idx, single_column in enumerate(experience_columns): - for i, index in enumerate(indexes): - self.experience_data[single_column][index] = experience[column_idx][i] - self.experience_data_status[single_column][index] = 1 + for local_index, offset in zip(local_indexes, experience_offset): + # input can rewrite experience data in TD + self.experience_data[single_column][local_index] = experience[column_idx][offset] - def _get(self, experience_columns: List[str], indexes: List[int]): + def get_data(self, experience_columns: List[str], indexes: List[int]): """Get data based on row and column numbers. Args: @@ -147,509 +271,285 @@ class TransferDock(ABC): ] """ - if len(indexes) == 0: - return [[] for _ in range(len(experience_columns))] - - if max(indexes) >= self.max_len: - raise ValueError( - f"Get experience index {max(indexes)} exceeds the Transfer Dock range {self.max_len}." - ) - + local_indexes = self._convert_index_global_to_local(indexes) experience = [] for single_column in experience_columns: - self._wait_for_data(single_column, indexes) - if len(indexes) == 1: - experience.append([self.experience_data[single_column][indexes[0]]]) + if len(local_indexes) == 1: + experience.append([self.experience_data[single_column][local_indexes[0]]]) else: - experience.append(list(itemgetter(*indexes)(self.experience_data[single_column]))) + experience.append(list(itemgetter(*local_indexes)(self.experience_data[single_column]))) return experience - def _wait_for_data(self, single_column: str, indexes: List[int]): - """Wait for data in which column and row to be ready. - Args: - single_column: Column that need to wait for data to be ready. - indexes: Rows that need to wait for data to be ready. + def register_td_managers(self, td_managers_list): + self.all_td_managers = td_managers_list - Returns: None - - """ - if len(indexes) == 1: - data_ready = self.experience_data_status[single_column][indexes] == 1 - else: - data_ready = sum(itemgetter(*indexes)(self.experience_data_status[single_column])) == len(indexes) - - start_time = time.time() - while not data_ready: - elapsed_time = time.time() - start_time - if ( - elapsed_time > self.timeout - and elapsed_time % self.timeout_interval < 0.1 - ): - logger.warning(f"TIMEOUT: data_ready has slept {elapsed_time} second") - time.sleep(0.1) - if len(indexes) == 1: - data_ready = self.experience_data_status[single_column][indexes] == 1 - else: - data_ready = sum( - itemgetter(*indexes)(self.experience_data_status[single_column]) - ) == len(indexes) - def _clear_experience_data_and_status(self, indexes=None): - """Clear data and data status in TransferDock. + def clear(self): + """Clear local data. Returns: None """ - if indexes is None: - self.experience_data = { - key: [None for _ in range(self.max_len)] - for key in self.experience_columns - } - self.experience_data_status = { - key: torch.zeros(self.max_len, dtype=torch.int32) - for key in self.experience_columns - } - else: - for key in self.experience_columns: - self.experience_data_status[key][indexes] = 0 - for key in self.experience_columns: - for idx in indexes: - self.experience_data[key][idx] = None - - def get_experience_data(self): - """Get all data in TransferDock. + self.experience_data = {key: [None for _ in range(self.max_len)] for key in self.experience_columns} - Returns: Data dict. 
- """ - return self.experience_data - - def get_experience_status(self): - """Get all data status in TransferDock. - - Returns: Data status dict. - - """ - return self.experience_data_status - - def get_experience_len(self): - """Get the maximum length of data in TransferDock. - - Returns: The maximum length of data. - - """ - return self.max_len +class DistributedStorageNode: + def __init__( + self, + max_len: int, + data_node_rank: int, + num_distribute_nodes: int, + experience_columns: Union[List[str], None] + ): + self.data_node_rank = data_node_rank + if max_len % num_distribute_nodes != 0: + raise ValueError(f"max_len:{max_len} need to be divisible " + f"by num_distribute_nodes:{num_distribute_nodes}") + + self.handler = TransferDockData.remote( + max_len=max_len, + data_node_rank=data_node_rank, + num_distribute_nodes=num_distribute_nodes, + experience_columns=experience_columns + ) -@ray.remote(max_concurrency=100, num_cpus=10) -class GRPOTransferDock(TransferDock): +class TransferDock(ABC): """ - GRPOTransferDock is based on TransferDock and supports managing data transfer between - GRPO asynchronous tasks in the Ray cluster. + TransferDock is a data structure class that serves as the base class for GRPOTransferDock, + providing data storage and retrieval functions. """ def __init__( self, prompts_num: int, n_samples_per_prompt: int, - metrics=None, - addition_columns: Union[List[str], None] = None, - addition_consumers: Union[List[str], None] = None, - timeout: Union[int, None] = None, - timeout_interval: Union[int, None] = None, + consumers_columns: Dict[str, List[str]], + distribute_nodes=None ) -> None: - """GRPOTransferDock initialize. + """TransferDock initialize. Args: prompts_num: The number of prompts loaded from the dataset. n_samples_per_prompt: The number of responses for a prompt. - metrics: The metrics stored in TransferDock. - addition_columns: Additional experience columns in TransferDock. - addition_consumers: Additional consumers in TransferDock. - timeout: The waiting time for over time printing. - timeout_interval: Time interval for timeout printing. + consumers_columns: Data columns in TransferDock. + distribute_nodes: Distribute nodes across multiple processes. 
""" - self.experience_columns = [ - "prompts", - "prompt_length", - "responses", - "response_length", - "attention_mask", - "labels", - "input_ids", - "input_ids_length", - "actor_rollout", - "rm_scores", - "token_level_rewards", - "old_log_prob", - "ref_log_prob", - "advantages", - "returns", - ] - self.experience_consumers = [ - "trainer", - "actor_rollout", - "actor_log_prob", - "ref_log_prob", - "actor_train", - "compute_advantage", - "rule_reward", - "reward_scores", - "grpo_metrics", - ] - if addition_columns: - for column in addition_columns: - if column not in self.experience_columns: - self.experience_columns.append(column) - - if addition_consumers: - for consumer in addition_consumers: - if consumer not in self.experience_consumers: - self.experience_consumers.append(consumer) - - super().__init__( - prompts_num, - n_samples_per_prompt, - self.experience_columns, - timeout, - timeout_interval, - ) - self.experience_consumer_status = { - key: torch.zeros(self.max_len, dtype=torch.int32) - for key in self.experience_consumers - } - self.consumer_sampling_lock = { - key: threading.Lock() - for key in self.experience_consumers - } - self.metrics = metrics - - def get_metrics(self): - return self.metrics - - def update_metrics(self, key="", value=None, cumulate=False): - self.metrics.update(key, value, cumulate=cumulate) + super().__init__() + self.prompts_num = prompts_num + self.n_samples_per_prompt = n_samples_per_prompt + self.max_len = prompts_num * n_samples_per_prompt - def get_experience( + # 消费者每次消耗数据类型 + self.consumers_columns = consumers_columns + # 消费者TD初始化 + self.consumers = {} + + for key, value in self.consumers_columns.items(): + self.consumers[key] = TransferDockConsumer(key, value, prompts_num, n_samples_per_prompt) + + # 采样状态锁,只能有一个dp线程进入选择ready的样本 + self.consumer_sampling_lock = {key: threading.Lock() for key in self.consumers_columns} + + self.distribute_nodes = distribute_nodes + self.num_distribute_nodes = len(distribute_nodes) + + def _get_meta_info(self, indexes: List[int] = None, consumer: str = None) -> List[DistributedMetadata]: + node_range = self.max_len // self.num_distribute_nodes + node_set = set() + metadata_list = [] + experience_offset = 0 + + if not indexes: ## TODO CHECK + indexes = list(range(self.max_len)) ## TODO CHECK + for index in indexes: + data_node_rank = index // node_range + if data_node_rank not in node_set: + node_set.add(data_node_rank) + metadata = DistributedMetadata() + metadata.data_node_rank = data_node_rank + metadata.handler = self.distribute_nodes[data_node_rank].handler + metadata.consumer_columns = self.consumers_columns[consumer] if consumer else None + metadata.indexes = [] + metadata.experience_offset = [] + metadata_list.append(metadata) + metadata_list[-1].indexes.append(index) + metadata_list[-1].experience_offset.append(experience_offset) + experience_offset += 1 + + return metadata_list + + def update_consumers_data_status(self, consumer: str, column_idx: int, indexes: List[int], status: int) -> None: + self.consumers[consumer].update_consumers_data_status(indexes, column_idx, status) + + def get_consumers_columns(self): + return self.consumers_columns + + def get_consumer_columns(self, consumer): + return self.consumers[consumer].consumers_columns + + def get_metadata_on_reading( self, consumer: str, - experience_columns: List[str], - experience_count: int = None, + experience_count: int, + order_preserving_flag: bool, indexes: List[int] = None, - get_n_samples: bool = True, - ): - """Get padded experience data 
from GRPOTransferDock. + get_n_samples: bool = True + ) -> Union[List[DistributedMetadata], None]: + """input global indexes Args: - consumer: GRPO task stage to get in. - experience_columns: Columns from which to get data. - experience_count: Number of data to get. - indexes: Rows from which to get data. - pad_id: Pad token. - multiple: The multiple of TP to pad. - get_n_samples: Whether to get n samples at the same time. - target_seq_len: Target sequence length. + consumer : consumer column + experience_count : consume data size + order_preserving_flag : order preserving flag + indexes : want to get indexes + get_n_samples : Whether to get n samples at the same time. - Returns: Data dict and row numbers. + Returns: + node handler and local indexes """ - if consumer not in self.experience_consumers: - raise ValueError( - f"get experience ERROR: {consumer} not in TD experience_consumers {self.experience_consumers}" - ) - - for experience_column in experience_columns: - if experience_column not in self.experience_columns: - raise ValueError( - f"get experience ERROR: {experience_column} not in TD experience_column {self.experience_columns}" - ) - - if indexes is None: - if experience_count > self.max_len: - raise ValueError( - f"TD max_len: {self.max_len} need >= experience_count: {experience_count}" - ) - - if self.max_len % experience_count != 0: - raise ValueError( - f"TD max_len:{self.max_len} need be divisible by experience_count: {experience_count}" - ) - - if get_n_samples: - if experience_count % self.n_samples_per_prompt != 0: - raise ValueError( - f"get_n_samples need experience_count:{experience_count} must be divisible by " - f"n_samples_per_prompt: {self.n_samples_per_prompt}" - ) - indexes = self._sample_ready_index_n_samples( - consumer, experience_count, experience_columns - ) - else: - indexes = self._sample_ready_index( - consumer, experience_count, experience_columns - ) - - if not indexes: - return None, None - experience = self._get(experience_columns, indexes) - else: - self.experience_consumer_status[consumer][indexes] = 1 - experience = self._get(experience_columns, indexes) - - experience_batch = {} - for i, experience_column in enumerate(experience_columns): - experience_batch[experience_column] = experience[i] - return experience_batch, indexes - - def put_experience( - self, - data_dict: Dict[str, Union[Tensor, List[Tensor]]], - indexes: List[int] = None, - ): - """Put data into specified columns and rows. + if indexes is not None: + if order_preserving_flag and \ + not self.consumers[consumer].ready_for_consume(indexes): + return None - Args: - data_dict: Data dict to put in GRPOTransferDock. - indexes: Rows to put data in. 
+ # 先更新状态再取数据,更新消费者的数据状态为已消费 + if consumer != 'actor_partial_rollout': + self.consumers[consumer].update_consumers_data_status(indexes, -1, 2) + return self._get_meta_info(indexes, consumer) - Returns: None + if experience_count > self.max_len: + raise ValueError(f"TD max_len: {self.max_len} need >= experience_count: {experience_count}") - """ + if consumer != "actor_partial_rollout" and self.max_len % experience_count != 0: + raise ValueError(f"TD max_len:{self.max_len} need be divisible by experience_count: {experience_count}") + # 根据消费者状态取ready的数据 + if get_n_samples: + if experience_count % self.n_samples_per_prompt != 0: + raise ValueError(f"get_n_samples need experience_count:{experience_count} " + f"must be divisible by n_samples_per_prompt: {self.n_samples_per_prompt}") + indexes = self.get_ready_index_n_samples(consumer, experience_count) + else: + indexes = self.get_ready_index(consumer, experience_count) if not indexes: - raise ValueError( - "put experience into TD without indexes, indexes must be provided" - ) - experience_columns, experience = trans_input_to_experience(data_dict) - self._put(experience_columns, experience, indexes) - - def put_prompts_experience( - self, batch: Dict[str, Tensor], dataset_additional_keys: List[str] = None - ): - """Put data into specified columns and rows. + return None + return self._get_meta_info(indexes, consumer) - Args: - batch: Batch datas from original dataloader. - dataset_additional_keys: The additional experience types from the dataset. - - Returns: None + def get_ready_index(self, consumer: str, experience_count: int) -> List[int]: - """ - - prompts = batch["prompts"] - prompt_length = [] - for prompt in prompts: - for _ in range(self.n_samples_per_prompt): - prompt_length.append(torch.tensor([len(prompt)])) - - prompts_data = prompts - prompts = [] - for prompt in prompts_data: - for _ in range(self.n_samples_per_prompt): - prompts.append(copy.deepcopy(prompt)) - - add_vals = {} - for add_keys in dataset_additional_keys: - if add_keys in batch.keys(): - values = [] - for value in batch[add_keys]: - for _ in range(self.n_samples_per_prompt): - values.append(value) - add_vals[add_keys] = values - - indexes = [i for i in range(len(prompt_length))] - data_dict = dict( - {"prompt_length": prompt_length, "prompts": prompts}, **add_vals - ) - experience_columns, experience = trans_input_to_experience(data_dict) - - self._put(experience_columns, experience, indexes) - - def _sample_ready_index( - self, - consumer: str, - experience_count: int, - experience_columns: List[str], - target_seq_len: int = None, - ) -> Optional[List[int]]: - """Randomly select a specified number of prepared experiences from TransferDock. - - Args: - consumer: GRPO task stage to sample in. - experience_count: Number for rows to sample. - experience_columns: Columns from which to sample. + with self.consumer_sampling_lock[consumer]: + sampled_indexes = self.consumers[consumer].get_ready_index(experience_count) - Returns: Sampled row numbers. 
+ return sampled_indexes - """ + def get_ready_index_n_samples(self, consumer: str, experience_count: int) -> List[int]: with self.consumer_sampling_lock[consumer]: - not_consumed_indexes = self.experience_consumer_status[consumer] == 0 - data_ready_indexes = torch.all( - torch.stack( - [self.experience_data_status[single_column] == 1 for single_column in experience_columns] - ), dim=0, - ) - usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0] - - if len(usable_indexes) < experience_count: - return None - - if experience_count > 0: - sampled_indexes = self.batch_balencing_sampler( - experience_columns, usable_indexes, experience_count, target_seq_len - ) - self.experience_consumer_status[consumer][sampled_indexes] = 1 - else: - sampled_indexes = None + sampled_indexes = self.consumers[consumer].get_ready_index_n_samples(experience_count) return sampled_indexes - def _sample_ready_index_n_samples( - self, - consumer: str, - experience_count: int, - experience_columns: List[str], - target_seq_len: int = None, - ) -> Optional[List[int]]: - """Randomly select a specified number of prepared experiences from TransferDock at multiples of n_sample. + def get_metadata_on_writing(self, indexes: List[int]) \ + -> List[DistributedMetadata]: + if indexes is None: + raise ValueError(f"put experience ERROR: indexes is None") + return self._get_meta_info(indexes) - Args: - consumer: GRPO task stage to sample in. - experience_count: Number for rows to sample. - experience_columns: Columns from which to sample. - target_seq_len: Sample according with seq_len and target_seq_len. + def get_all_metadata(self, consumer: str = None): + return self._get_meta_info(consumer=consumer) - Returns: Sampled row numbers. - """ - experience_count_n_samples = experience_count // self.n_samples_per_prompt - with self.consumer_sampling_lock[consumer]: - experience_consumer_status_n_samples = ( - 1 - torch.all( - torch.tensor( - torch.reshape( - self.experience_consumer_status[consumer], - (self.prompts_num, self.n_samples_per_prompt), - ) == 0 - ), dim=1, - ).int() - ) - not_consumed_indexes = experience_consumer_status_n_samples == 0 - - experience_data_status_n_samples = {} - for key, value in self.experience_data_status.items(): - experience_data_status_n_samples[key] = torch.all( - torch.tensor( - torch.reshape(value, (self.prompts_num, self.n_samples_per_prompt)) == 1 - ), dim=1, - ).int() - - data_ready_indexes = torch.all( - torch.stack( - [experience_data_status_n_samples.get(single_column) == 1 for single_column in experience_columns]), - dim=0, - ) - - usable_indexes = (not_consumed_indexes & data_ready_indexes).nonzero(as_tuple=True)[0] - - if len(usable_indexes) < experience_count_n_samples: - return None + def report_data_produce_status(self, experience_columns: List[str], indexes: List[int]): + """report put data status - sampled_indexes_n_sample = self.batch_balencing_sampler( - experience_columns, - usable_indexes, - experience_count_n_samples, - target_seq_len, - ) + Args: + experience_columns: Columns from which to get data. + indexes: Rows from which to get data. 
- sampled_indexes = [] - for n_sample_index in sampled_indexes_n_sample: - index_list = [] - for index in range( - n_sample_index * self.n_samples_per_prompt, - (n_sample_index + 1) * self.n_samples_per_prompt - ): - index_list.append(index) + """ + for _, td_consumer in self.consumers.items(): + td_consumer.update_data_write_status(experience_columns, indexes) - sampled_indexes += index_list + def get_experience_len(self): + """Get the maximum length of data in TransferDock. - self.experience_consumer_status[consumer][sampled_indexes] = 1 + Returns: The maximum length of data. - return sampled_indexes + """ + return self.max_len def all_consumed(self, consumer: str): - """If consumer has consumed all data in GRPOTransferDock. + """If consumer has consumed all data in TransferDock. Args: - consumer: GRPO task stage to consume in. + consumer: + """ + return self.consumers[consumer].all_consumed() + + def get_consumer_status(self, consumer: str): + """Get consumer status. - Returns: True or False. + Returns: Consumer status dict. """ - return self.experience_consumer_status[consumer].sum() == self.max_len + return self.consumers[consumer].consumers_data_status def clear(self): - """Reset consumer status.Clear data and data status in GRPOTransferDock. + """Reset consumer status.Clear data and data status in TransferDock. Returns: None - """ - self.experience_consumer_status = { - key: torch.zeros(self.max_len, dtype=torch.int32) - for key in self.experience_consumers - } - self.metrics.reset() - self._clear_experience_data_and_status() + for key in self.consumers_columns.keys(): + self.consumers[key].clear() + refs = [] + for node in self.distribute_nodes: + refs.append(node.handler.clear.remote()) + ray.get(refs) - def get_consumer_status(self): - """Get consumer status. - Returns: Consumer status dict. - - """ - return self.experience_consumer_status +@ray.remote(max_concurrency=100, num_cpus=10) +class TransferDockManager(TransferDock): + def __init__( + self, + consumers_list: Union[str, List[str]], + consumers_columns: Dict[str, List[str]], + prompts_num: int, + n_samples_per_prompt: int, + distribute_data_nodes=None, + ) -> None: - def batch_balencing_sampler( - self, experience_columns, usable_indexes, experience_count, target_seq_len=None - ): - if target_seq_len is None: - weights = torch.ones(len(usable_indexes)) - else: - seq_len = torch.tensor( - [ - sum([self.experience_data[key][idx].numel() for key in experience_columns]) - for idx in usable_indexes - ] - ) - weights = torch.sigmoid(1 / (torch.abs(seq_len - target_seq_len) + 0.001), dim=0) + manager_consumers_columns = {} + for key in consumers_list: + if key in consumers_columns: + manager_consumers_columns.update({key: consumers_columns[key]}) + else: + raise ValueError(f"consumer stage {key} is not in consumers_columns") - sampled_indexes_idx = torch.multinomial(weights, experience_count, replacement=False).tolist() - sampled_indexes = [int(usable_indexes[i]) for i in sampled_indexes_idx] - - return sampled_indexes + super().__init__(prompts_num, n_samples_per_prompt, manager_consumers_columns, distribute_data_nodes) def pad_experience( - experience_batch: Dict[str, List[Tensor]], - pad_id: int, - multiple: int = 1, + experience_batch: Dict[str, List[Tensor]], + pad_id: int, + multiple: int = 1, ): """ Pad dict data. 
Args: experience_batch: Dict { - 'prompts': [ tensor([1, 1, 1, 1]), + 'prompts': [ tensor([1, 1, 1, 1]), tensor([2, 2, 2, 2]), - tensor([3, 3, 3, 3]), - tensor([4, 4, 4, 4])], - 'attention_mask': [ tensor([1]), - tensor([2, 2]), - tensor([3, 3, 3]), - tensor([4, 4, 4, 4])], + tensor([3, 3, 3, 3]), + tensor([4, 4, 4, 4])], + 'attention_mask': [ tensor([1]), + tensor([2, 2]), + tensor([3, 3, 3]), + tensor([4, 4, 4, 4])], } pad_id: Pad token. 0.0 @@ -695,8 +595,6 @@ def pad_experience( for experience_column, experience in experience_batch.items(): if experience_column in ["prompt_length", "response_length"]: padded = torch.cat(experience).reshape(-1, 1) - elif experience_column in ["position_ids"]: - padded = pad_sequence(experience, batch_first=True, padding_value=pad_id) elif experience[0].is_floating_point(): padded = pad_multiples(experience, pad_id=0.0, multiple=multiple) else: @@ -763,19 +661,139 @@ def trans_input_to_experience(experience_dict: Dict[str, Union[Tensor, List[Tens return experience_columns, experience_list -def pack_experience_columns(experience_dict, experience_count): +def get_experience( + manager_handler: ray.actor.ActorHandle, + consumer: str, + consume_batch_size: int, + order_preserving_flag: bool = False, + indexes: List[int] = None, + get_n_samples: bool = True +) -> Union[Tuple[Dict[str, Tensor], List[int]], Tuple[None, None]]: + """Get padded experience data from GRPOTransferDock. + + Args: + manager_handler: manager ray handler + consumer: GRPO task stage to get in. + consume_batch_size: consume data size + order_preserving_flag: order preserving flag + indexes: Rows from which to get data. + if is None, make sure order_preserving_flag is False + pad_id: Pad token. + multiple: The multiple of TP to pad. + get_n_samples: Whether to get n samples at the same time. + + Returns: Data dict and row numbers. """ + if indexes is None and order_preserving_flag: + raise ValueError(f"get experience ERROR: indexes is None and order_preserving_flag is True") + + metadata_list = ray.get(manager_handler.get_metadata_on_reading. + remote(consumer, consume_batch_size, order_preserving_flag, indexes, get_n_samples)) + if metadata_list is None: + time.sleep(0.1) + return None, None + + indexes = list(itertools.chain.from_iterable( + [metadata.indexes for metadata in metadata_list])) \ + if indexes is None else indexes + refs = [metadata.handler.get_data.remote(metadata.consumer_columns, metadata.indexes) + for metadata in metadata_list] + ray_results = ray.get(refs) + + experience = [sum((sub[sub_idx] for sub in ray_results), []) + for sub_idx in range(len(ray_results[0]))] + + experience_batch = {} + experience_columns = metadata_list[0].consumer_columns + for i, ec in enumerate(experience_columns): + experience_batch[ec] = experience[i] + + return experience_batch, indexes + + +def put_experience( + manager_handler: ray.actor.ActorHandle, + data_dict: Dict[str, Union[Tensor, List[Tensor]]], + indexes: List[int] +) -> None: + """Put data into specified columns and rows. + + Args: + manager_handler: manager ray handler + data_dict: Data dict to put in TransferDock. + indexes: Rows to put data in. 
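+
+    Example (illustrative only; the column name, tensor value and index are made up):
+        put_experience(td_manager, {'responses': [torch.tensor([7, 8, 9])]}, indexes=[0])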
+ + Returns: None + + """ + experience_columns, experience = trans_input_to_experience(data_dict) + + metadata_list = ray.get(manager_handler.get_metadata_on_writing.remote(indexes)) + if metadata_list is None: + logger.warning("put experience return None.") + return + + ray.get([metadata.handler.put_data.remote( + experience_columns, experience, metadata.experience_offset, metadata.indexes) + for metadata in metadata_list]) + + +def put_prompts( + manager_handler: ray.actor.ActorHandle, + batch: Dict[str, Tensor], + n_samples_per_prompt: int, + dataset_additional_keys: List[str] +): + """Put data into specified columns and rows. + + Args: + manager_handler: manager ray handler + batch: Batch datas from original dataloader. + n_samples_per_prompt: The number of samples generated per prompt. + dataset_additional_keys: The additional experience types from the dataset. + + Returns: None + + """ + prompts = batch['prompts'] + prompt_length = [] + for prompt in prompts: + for _ in range(n_samples_per_prompt): + prompt_length.append(torch.tensor([len(prompt)])) + + prompts_data = prompts + prompts = [] + for prompt in prompts_data: + for _ in range(n_samples_per_prompt): + prompts.append(copy.deepcopy(prompt)) + + add_vals = {} + for add_keys in dataset_additional_keys: + if add_keys in batch.keys(): + values = [] + for value in batch[add_keys]: + for _ in range(n_samples_per_prompt): + values.append(value) + add_vals[add_keys] = values + + indexes = [i for i in range(len(prompt_length))] + data_dict = dict({'prompt_length': prompt_length, 'prompts': prompts}, **add_vals) + put_experience(manager_handler, data_dict, indexes) + + +def pack_experience_columns(experience_dict, experience_count): + """ Compress experiences by packing tensors into ONE. from experience_dict { - 'prompts': [ tensor([1, 1, 1]), + 'prompts': [ tensor([1, 1, 1]), tensor([2, 2, 2, 2]), - tensor([3, 3, 3]), - tensor([4, 4, 4, 4])], - 'attention_mask': [ tensor([1]), - tensor([2, 2]), - tensor([3, 3, 3]), - tensor([4, 4, 4, 4])], + tensor([3, 3, 3]), + tensor([4, 4, 4, 4])], + 'attention_mask': [ tensor([1]), + tensor([2, 2]), + tensor([3, 3, 3]), + tensor([4, 4, 4, 4])], } To batch_data { @@ -784,7 +802,7 @@ def pack_experience_columns(experience_dict, experience_count): } batch_data_length { - 'prompts': tensor([3, 4, 3, 4]), + 'prompts': tensor([3, 4, 3, 4]), 'attention_mask': tensor([1, 2, 3, 4]) } """ @@ -798,40 +816,14 @@ def pack_experience_columns(experience_dict, experience_count): for key, value in experience_dict.items(): if len(value) != experience_count: raise ValueError(f"ERROR: when pack, experience '{key}' number does not match experience_count") + packed_experience = [] + data_length = [] + for i in range(experience_count): + packed_experience.extend(value[i].tolist()) + data_length.append(len(value[i])) - # 判断是一维张量还是二维张量 - is_2d = len(value[0].shape) > 1 - if is_2d: - # 处理二维张量,如position_ids - first_dim = value[0].shape[0] - # 确保所有张量的第一维相同 - for i in range(experience_count): - if value[i].shape[0] != first_dim: - raise ValueError(f"ERROR: when pack 2D tensor, first dimension must be the same for all experiences") - - # 准备存储连接后的二维张量 - packed_data = [] - for dim_idx in range(first_dim): - dim_data = [] - for i in range(experience_count): - dim_data.extend(value[i][dim_idx].tolist()) - packed_data.append(dim_data) - - batch_data[key] = torch.tensor(packed_data, dtype=value[0].dtype) - - # 仅记录第二维的长度 - data_length = [value[i].shape[1] for i in range(experience_count)] - batch_data_length[key] = 
torch.tensor(data_length, dtype=torch.int32) - else: - # 原有的一维张量处理逻辑 - packed_experience = [] - data_length = [] - for i in range(experience_count): - packed_experience.extend(value[i].tolist()) - data_length.append(len(value[i])) - - batch_data[key] = torch.tensor(packed_experience, dtype=value[0].dtype) - batch_data_length[key] = torch.tensor(data_length, dtype=torch.int32) + batch_data[key] = torch.tensor(packed_experience, dtype=value[0].dtype) + batch_data_length[key] = torch.tensor(data_length, dtype=torch.int32) return batch_data, batch_data_length @@ -882,64 +874,28 @@ def unpack_pad_experience(batch_data, batch_data_length, pad_id, multiple): lengths = length_list.to(data_device) - # 判断是一维还是二维张量 - is_2d = len(data.shape) > 1 - if is_2d: - # 处理二维张量,如position_ids - first_dim = data.shape[0] - - # 计算最大长度 - max_row_len = torch.max(lengths).item() - if multiple > 1: - max_row_len = ((max_row_len + multiple - 1) // multiple) * multiple - - # 创建结果张量,每个样本是一个单独的2D张量 - sample_count = len(lengths) - result = [] - - # 预分配张量 - if data[0].is_floating_point(): - padded_tensor = torch.full((sample_count, first_dim, max_row_len), 0.0, - dtype=data_dtype, device=data_device) - else: - padded_tensor = torch.full((sample_count, first_dim, max_row_len), pad_id, - dtype=data_dtype, device=data_device) - - # 计算累积长度 - cum_length = torch.cat([torch.tensor([0], device=data_device), - torch.cumsum(lengths, 0)]) - - # 填充每个样本 - for i in range(sample_count): - seq_len = lengths[i] - for dim_idx in range(first_dim): - start_idx = cum_length[i] - end_idx = cum_length[i] + seq_len - padded_tensor[i, dim_idx, :seq_len] = data[dim_idx, start_idx:end_idx] - - padded_batch_data[key] = padded_tensor - else: - # 原有的一维张量处理逻辑 - # 计算最大长度 - max_row_len = torch.max(lengths).item() - if multiple > 1: - max_row_len = ((max_row_len + multiple - 1) // multiple) * multiple - - # 预分配张量 - if data.is_floating_point(): - padded_tensor = torch.full((len(lengths), max_row_len), 0.0, + # 计算最大长度 + max_row_len = torch.max(lengths).item() + if multiple > 1: + max_row_len = ((max_row_len + multiple - 1) // multiple) * multiple + + # 预分配张量 + if data[0].is_floating_point(): + padded_tensor = torch.full((len(lengths), max_row_len), 0.0, dtype=data_dtype, device=data_device) - else: - padded_tensor = torch.full((len(lengths), max_row_len), pad_id, + else: + padded_tensor = torch.full((len(lengths), max_row_len), pad_id, dtype=data_dtype, device=data_device) - # 向量化填充 - cum_length = torch.cat([torch.tensor([0], device=data_device + # 向量化填充 + cum_length = torch.cat([torch.tensor([0], device=data_device ), torch.cumsum(lengths, 0)]) - for i, _ in enumerate(lengths): - seq_len = lengths[i] - padded_tensor[i, :seq_len] = data[cum_length[i]:cum_length[i + 1]] - padded_batch_data[key] = padded_tensor + for i, _ in enumerate(lengths): + seq_len = lengths[i] + padded_tensor[i, :seq_len] = data[cum_length[i]:cum_length[i + 1]] + padded_batch_data[key] = padded_tensor return padded_batch_data + + diff --git a/mindspeed_rl/utils/metrics.py b/mindspeed_rl/utils/metrics.py index b4c1a3e2..8718a1b9 100644 --- a/mindspeed_rl/utils/metrics.py +++ b/mindspeed_rl/utils/metrics.py @@ -5,6 +5,7 @@ from typing import Dict import numpy as np import torch +import ray class Metric(ABC): @@ -109,3 +110,42 @@ class Metric(ABC): 重置指标状态。 """ self.metric = {} + + def metrics_post_processing(self) -> Dict[str, torch.Tensor]: + """ + Calculate the mean of each metric across a list of metric dictionaries. 
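+
+        Illustrative example (made-up values; assumes compute_max/compute_min return the
+        list max/min): with self.metric == {"timing/rollout": [12.5, 3.5], "grpo/lr": 1e-6},
+        the result is {"timing/rollout": 9.0, "grpo/lr": 1e-6}, since "timing" lists are
+        reduced to max - min.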
+ + Args: + metrics_list: A list of dictionaries, where each dictionary contains metrics as key-value pairs. + + Returns: + metrics_mean: A dictionary where each key is a metric name and + each value is the mean of that metric across all batches. + """ + new_metrics = {} + for key, value in self.metric.items(): + if "timing" in key: + if isinstance(value, list): + new_metrics[key] = self.compute_max(key, value) - self.compute_min(key, value) + else: + new_metrics[key] = value + elif "start_time" in key: + if isinstance(value, list): + new_metrics[key] = self.compute_min(key, value) + else: + new_metrics[key] = value + elif "end_time" in key: + if isinstance(value, list): + new_metrics[key] = self.compute_max(key, value) + else: + new_metrics[key] = value + elif isinstance(value, list): + new_metrics[key] = self.compute_mean(key, value) + else: + new_metrics[key] = value + return new_metrics + + +@ray.remote +class ActorMetric(Metric): + pass \ No newline at end of file diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index e4f5edb6..8a8235c8 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -21,7 +21,7 @@ from mindspeed_rl.utils.tokenizer import BaseTokenizer from mindspeed_rl.utils.utils import MsProbe from mindspeed_rl.workers.base_worker import BaseWorker from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronShardingManager, MegatronOffLoader -from mindspeed_rl.utils.utils import num_floating_point_operations, get_attr_wrapped_model, mstx_timer_decorator, profiler_start, profiler_step, is_multimodal +from mindspeed_rl.utils.utils import num_floating_point_operations, get_attr_wrapped_model, mstx_timer_decorator, profiler_start, profiler_step from mindspeed_rl.utils.pad_process import remove_padding_and_split_to_list, truncate_rows @@ -89,7 +89,7 @@ class ActorHybridWorkerBase(BaseWorker): self.inference_model = self._build_rollout() self.sharding_manager = self._build_sharding_manager() megatron_module = self.get_megatron_module() - + if self.generate_config.offload_train_param: self.actor_offloader.onload_param() @@ -123,17 +123,22 @@ class ActorHybridWorkerBase(BaseWorker): self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role) MsProbe.config_init(self.msprobe_config) - def init_transfer_dock(self, td, mm_td): + def init_transfer_dock(self, td): self.td = td - self.mm_td = mm_td self.empty_cache() + def init_metrics(self, metrics): + self.metrics = metrics + def get_iteration(self): return self.args.iteration def get_consumed_train_samples(self): return self.args.consumed_train_samples + def get_distributed_node_num(self): + return self.parallel_state.get_data_parallel_world_size() + @mstx_timer_decorator def update(self, kl_ctrl=None, skip_actor_log_prob=False): if skip_actor_log_prob: @@ -148,8 +153,6 @@ class ActorHybridWorkerBase(BaseWorker): experience_columns = ['responses', 'advantages', 'old_log_prob', 'ref_log_prob', 'input_ids', 'response_length', 'prompt_length'] - if is_multimodal(): - experience_columns.extend(['attention_mask', 'position_ids']) if self.rl_config.use_integrated_worker: experience_count = ( @@ -166,7 +169,7 @@ class ActorHybridWorkerBase(BaseWorker): learning_rate = None for param_group in self.optimizer.param_groups: learning_rate = param_group['lr'] - ray.get(self.td.update_metrics.remote(key='grpo/lr', value=learning_rate)) + ray.get(self.metrics.update.remote(key='grpo/lr', 
value=learning_rate)) sorted_indexes = self.get_dp_range_indexes(experience_count, use_vllm=False) if self.rl_config.guarantee_order else None actor_update_profiler = profiler_start(self.profiler_config, role="actor_update", @@ -193,10 +196,10 @@ class ActorHybridWorkerBase(BaseWorker): self.args.consumed_train_samples += self.megatron_config.global_batch_size // self.rl_config.n_samples_per_prompt self.num_floating_point_operations_so_far += num_floating_point_operations(self.args, self.megatron_config.global_batch_size) - if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.parallel_state.get_tensor_model_parallel_rank() == 0 and self.parallel_state.get_context_parallel_rank() == 0: - ray.get(self.td.update_metrics.remote(value=metrics, cumulate=True)) + if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True) and self.parallel_state.get_tensor_model_parallel_rank() == 0 and self.parallel_state.get_context_parallel_rank() == 0: + ray.get(self.metrics.update.remote(value=metrics, cumulate=True)) ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/update", value=[round(time.time(), 4), round(start_time, 4)], cumulate=True @@ -212,7 +215,7 @@ class ActorHybridWorkerBase(BaseWorker): self.sharding_manager.exit_train_mode() sharding_train_interval += (time.time() - start_sharding_exit_train) ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/resharding_to_train", value=[sharding_train_interval], cumulate=True @@ -234,8 +237,6 @@ class ActorHybridWorkerBase(BaseWorker): experience_consumer_stage = 'actor_rollout' experience_columns = ['prompts', 'prompt_length'] - if is_multimodal(): - experience_columns.extend(['input_ids', 'input_ids_length']) experience_count = self.rl_config.actor_rollout_dispatch_size @@ -246,7 +247,7 @@ class ActorHybridWorkerBase(BaseWorker): actor_generate_profiler = profiler_start(self.profiler_config, role="actor_generate", profiler_iteration=self.prof_iteration) MsProbe.debugger_start(self.inference_model.model, tag='actor_generate_sequences') - + start_time_defined = False while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True) > 0: batch_data, index = self.dispatch_transfer_dock_data( @@ -270,16 +271,12 @@ class ActorHybridWorkerBase(BaseWorker): prompts = truncate_rows(prompts_data, prompt_length_data) prompts_list = [prompt.numpy().tolist() for prompt in prompts] - responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list), extra_info=batch_data) + responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list)) responses = remove_padding_and_split_to_list(responses_pad_right, self.tokenizer.eod, pad_token_id) responses_length = [torch.tensor([len(response)]) for response in responses] - if is_multimodal(): - prompts_data = batch_data['input_ids'][indexes].unbind() - else: - prompts_data = prompts - + prompts_data = prompts prompts = [] for prompt in prompts_data: for _ in range(self.rl_config.n_samples_per_prompt): @@ -294,15 +291,13 @@ class ActorHybridWorkerBase(BaseWorker): 'input_ids': input_ids_list, 'response_length': responses_length } - if is_multimodal(): - outputs['prompt_length'] = batch_data['input_ids_length'] self.collect_transfer_dock_data(outputs, index, use_vllm=True) end_time = time.time() MsProbe.save_data({"responses": responses, "prompts": prompts}) ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/rollout", value=[round(end_time, 4), 
round(start_time, 4)], cumulate=True @@ -315,7 +310,7 @@ class ActorHybridWorkerBase(BaseWorker): self.sharding_manager.exit_infer_mode() sharding_infer_interval += (time.time() - start_sharding_exit_infer) ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/resharding_to_infer", value=[sharding_infer_interval], cumulate=True @@ -328,8 +323,6 @@ class ActorHybridWorkerBase(BaseWorker): experience_consumer_stage = 'actor_log_prob' experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length'] - if is_multimodal(): - experience_columns.extend(['attention_mask', 'position_ids', 'input_ids_length']) experience_count = self.rl_config.actor_logprob_dispatch_size sorted_indexes = self.get_dp_range_indexes(experience_count, use_vllm=False) if self.rl_config.guarantee_order else None @@ -362,14 +355,14 @@ class ActorHybridWorkerBase(BaseWorker): self.collect_transfer_dock_data(output, index) end_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/old_log_p", value=[round(end_time, 4), round(start_time, 4)], cumulate=True ) ) ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "end_time/old_log_p", value=[round(end_time, 4)], cumulate=True diff --git a/mindspeed_rl/workers/base_worker.py b/mindspeed_rl/workers/base_worker.py index c05a3e7c..01d4b267 100644 --- a/mindspeed_rl/workers/base_worker.py +++ b/mindspeed_rl/workers/base_worker.py @@ -35,9 +35,8 @@ from mindspeed_rl.trainer.utils.parallel_state import ( ) from mindspeed_rl.utils.compute import set_parallel_state, set_vocab_parallel from mindspeed_rl.utils.utils import get_current_dp_range_indexes -from mindspeed_rl.trainer.utils.transfer_dock import pack_experience_columns, unpack_pad_experience -from mindspeed_rl.trainer.utils.mm_transfer_dock import unpack_mm_experience -from mindspeed_rl.utils.utils import mstx_timer_decorator, is_multimodal +from mindspeed_rl.trainer.utils.transfer_dock import get_experience, put_experience, pack_experience_columns, unpack_pad_experience +from mindspeed_rl.utils.utils import mstx_timer_decorator logger = Loggers("base_worker") @@ -132,6 +131,7 @@ class BaseWorker(BaseRayWorker, ABC): megatron_config: MegatronConfig = None, rl_config: RLConfig = None, generate_config: GenerateConfig = None, + actor_fwd_config: MegatronConfig = None, model_provider: Callable = None, initialize_func: Callable = None, get_megatron_module: Callable = None, @@ -145,6 +145,7 @@ class BaseWorker(BaseRayWorker, ABC): self.rl_config = rl_config self.megatron_config = megatron_config self.generate_config = generate_config + self.actor_fwd_config = actor_fwd_config self.profiler_config = profiler_config self.msprobe_config = msprobe_config self.model_provider = model_provider @@ -160,7 +161,6 @@ class BaseWorker(BaseRayWorker, ABC): self.model_type = None self.model = None self.td = None - self.mm_td = None self.args = None @mstx_timer_decorator @@ -177,11 +177,11 @@ class BaseWorker(BaseRayWorker, ABC): rank_flg = False if not use_vllm: - rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and + rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0) else: - rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and + rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and 
get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0) if rank_flg: status = torch.tensor(int(not ray.get(self.td.all_consumed.remote(experience_consumer_stage))), @@ -234,19 +234,6 @@ class BaseWorker(BaseRayWorker, ABC): def td(self, value): self._td = value - @property - def mm_td(self): - """ - worker需要设置td(数据队列)后才可以使用,这里添加判断 - """ - if self._mm_td is None: - raise ValueError("MultiModal Transfer Dock is not initialized") - return self._mm_td - - @mm_td.setter - def mm_td(self, value): - self._mm_td = value - @mstx_timer_decorator def empty_cache(self): """Clear GPU cache (can be overridden by subclasses)""" @@ -258,34 +245,26 @@ class BaseWorker(BaseRayWorker, ABC): use_vllm=False, indexes=None, get_n_samples=True): pad_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod - if is_multimodal(): - mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage)) - else: - mm_columns = [] batch_data = {} batch_data_length = {} - batch_mm_data = {} # make sure that all ranks in cp/tp/pp group enter dispatch_transfer_dock_data, # in case of rank0 get_experience before other ranks judge td.all_consumed rank_flg = False if not use_vllm: - rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and + rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and get_context_parallel_rank(self.parallel_state, use_vllm) == 0 and get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0) else: - rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and + rank_flg = (get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0 and get_pipeline_model_parallel_rank(self.parallel_state, use_vllm) == 0) - + if rank_flg: - batch_data, index = ray.get(self.td.get_experience.remote(experience_consumer_stage, experience_columns, - experience_count, indexes=indexes, - get_n_samples=get_n_samples)) # cpu数据 + batch_data, index = get_experience(self.td, experience_consumer_stage, experience_count, + indexes=indexes, get_n_samples=get_n_samples) # cpu数据 if not index: # 判断是否取出数据,未取出数据为-1 index = [-1] * experience_count - elif is_multimodal(): - batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index, get_n_samples)) index = torch.tensor(index + ([-1] * (experience_count - len(index)))).cuda() else: @@ -326,17 +305,10 @@ class BaseWorker(BaseRayWorker, ABC): else: batch_data_dtype = torch.tensor(2, dtype=torch.int64, device=torch.cuda.current_device()) - - # 添加维度信息 - if key not in batch_data.keys(): - raise KeyError(f'{key} is missing!') - batch_data_ndim = torch.tensor(len(batch_data[key].shape), - dtype=torch.int64, device=torch.cuda.current_device()) else: - batch_data_shape = torch.empty(2, device=torch.cuda.current_device(), dtype=torch.int64) # 最多支持二维张量 + batch_data_shape = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64) batch_data_dtype = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64) batch_data_length_shape = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64) - batch_data_ndim = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64) # TP domain sync torch.distributed.broadcast( @@ -349,10 +321,7 @@ class BaseWorker(BaseRayWorker, ABC): ) torch.distributed.broadcast(batch_data_length_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm), group=get_tensor_model_parallel_group(self.parallel_state, use_vllm)) - torch.distributed.broadcast( - batch_data_ndim, 
get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_tensor_model_parallel_group(self.parallel_state, use_vllm) - ) + # CP domain sync if not use_vllm: torch.distributed.broadcast( @@ -365,10 +334,7 @@ class BaseWorker(BaseRayWorker, ABC): ) torch.distributed.broadcast(batch_data_length_shape, get_context_parallel_src_rank(self.parallel_state, use_vllm), group=get_context_parallel_group(self.parallel_state, use_vllm)) - torch.distributed.broadcast( - batch_data_ndim, get_context_parallel_src_rank(self.parallel_state, use_vllm), - group=get_context_parallel_group(self.parallel_state, use_vllm) - ) + # PP domain sync torch.distributed.broadcast( batch_data_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), @@ -380,30 +346,17 @@ class BaseWorker(BaseRayWorker, ABC): ) torch.distributed.broadcast(batch_data_length_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm)) - torch.distributed.broadcast( - batch_data_ndim, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm) - ) + if not rank_flg: - if batch_data_ndim == 1: # 一维张量处理 - if batch_data_dtype == 1: - batch_data[key] = torch.empty(batch_data_shape[0], # batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.int32) - else: - batch_data[key] = torch.empty(batch_data_shape[0], # batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.float32) - else: # 二维张量处理 - if batch_data_dtype == 1: - batch_data[key] = torch.empty(batch_data_shape[0], batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.int32) - else: - batch_data[key] = torch.empty(batch_data_shape[0], batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.float32) + if batch_data_dtype == 1: + batch_data[key] = torch.empty(batch_data_shape[0], # batch_data_shape[1], + device=torch.cuda.current_device(), + dtype=torch.int32) + else: + batch_data[key] = torch.empty(batch_data_shape[0], # batch_data_shape[1], + device=torch.cuda.current_device(), + dtype=torch.float32) batch_data_length[key] = torch.empty(batch_data_length_shape[0], device=torch.cuda.current_device(), dtype=torch.int32) # 传输tensor数据 @@ -417,7 +370,7 @@ class BaseWorker(BaseRayWorker, ABC): batch_data[key].cuda(), get_context_parallel_src_rank(self.parallel_state, use_vllm), group=get_context_parallel_group(self.parallel_state, use_vllm) ) - + torch.distributed.broadcast( batch_data[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm) @@ -435,17 +388,10 @@ class BaseWorker(BaseRayWorker, ABC): index_without_pad = index.cpu().numpy().tolist()[:batch_data_shape[0]] - if len(mm_columns) > 0: - batch_mm_data = self.get_batch_mm_data(batch_mm_data, mm_columns, rank_flg, use_vllm) - if batch_data: - if is_multimodal(): - padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, 1) - batch_mm_data = unpack_mm_experience(batch_mm_data) - padded_batch_data.update(batch_mm_data) - else: + if cp_algo in ['ulysses_cp_algo']: padded_batch_data = unpack_pad_experience(batch_data, batch_data_length, pad_id, tp_size * cp_size) - + return padded_batch_data, index_without_pad else: return {}, [] @@ -455,14 +401,7 @@ class BaseWorker(BaseRayWorker, ABC): if is_pipeline_last_stage(self.parallel_state, use_vllm) and 
get_tensor_model_parallel_rank(self.parallel_state, use_vllm) == 0: output = {key: value.cpu() if not isinstance(value, List) else value for key, value in output.items()} - self.td.put_experience.remote(data_dict=output, indexes=index) - - @mstx_timer_decorator - def collect_transfer_dock_mm_data(self, output, index, use_vllm=False): - if is_pipeline_last_stage(self.parallel_state, use_vllm) and get_tensor_model_parallel_rank(self.parallel_state, - use_vllm) == 0: - output = {key: value.cpu() if not isinstance(value, List) else value for key, value in output.items()} - self.mm_td.put_experience.remote(batch=output, indexes=index) + put_experience(self.td, data_dict=output, indexes=index) def get_dp_range_indexes(self, experience_count, use_vllm=False): if use_vllm: @@ -487,77 +426,3 @@ class BaseWorker(BaseRayWorker, ABC): return current_dp_rank, len(vllm_dp_groups) - def get_batch_mm_data(self, batch_mm_data, mm_columns, rank_flg, use_vllm): - for key in mm_columns: - if rank_flg: - if key not in batch_mm_data.keys(): - raise KeyError(f'{key} is missing!') - batch_data_shape = torch.tensor( - batch_mm_data[key].shape, dtype=torch.int64, device=torch.cuda.current_device()) - - if batch_mm_data[key].dtype == torch.int64: - batch_data_dtype = torch.tensor( - 1, dtype=torch.int64, device=torch.cuda.current_device()) - elif batch_mm_data[key].dtype == torch.bfloat16: - batch_data_dtype = torch.tensor( - 2, dtype=torch.int64, device=torch.cuda.current_device()) - else: - batch_data_dtype = torch.tensor( - 3, dtype=torch.int64, device=torch.cuda.current_device()) - else: - batch_data_shape = torch.empty(2, device=torch.cuda.current_device(), dtype=torch.int64) - batch_data_dtype = torch.empty(1, device=torch.cuda.current_device(), dtype=torch.int64) - - # TP domain sync - torch.distributed.broadcast( - batch_data_shape, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_tensor_model_parallel_group(self.parallel_state, use_vllm) - ) - torch.distributed.broadcast( - batch_data_dtype, get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_tensor_model_parallel_group(self.parallel_state, use_vllm) - ) - # CP domain sync - if not use_vllm: - torch.distributed.broadcast( - batch_data_shape, get_context_parallel_src_rank(self.parallel_state, use_vllm), - group=get_context_parallel_group(self.parallel_state, use_vllm) - ) - torch.distributed.broadcast( - batch_data_dtype, get_context_parallel_src_rank(self.parallel_state, use_vllm), - group=get_context_parallel_group(self.parallel_state, use_vllm) - ) - # PP domain sync - torch.distributed.broadcast( - batch_data_shape, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm) - ) - torch.distributed.broadcast( - batch_data_dtype, get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm) - ) - - if not rank_flg: - if batch_data_dtype == 1: - batch_mm_data[key] = torch.empty(batch_data_shape[0], batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.int64) - elif batch_data_dtype == 2: - batch_mm_data[key] = torch.empty(batch_data_shape[0], batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.bfloat16) - else: - batch_mm_data[key] = torch.empty(batch_data_shape[0], batch_data_shape[1], - device=torch.cuda.current_device(), - dtype=torch.float32) - - # 传输tensor数据 - torch.distributed.broadcast( - 
batch_mm_data[key].cuda(), get_tensor_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_tensor_model_parallel_group(self.parallel_state, use_vllm) - ) - torch.distributed.broadcast( - batch_mm_data[key].cuda(), get_pipeline_model_parallel_src_rank(self.parallel_state, use_vllm), - group=get_pipeline_model_parallel_group(self.parallel_state, use_vllm) - ) - return batch_mm_data \ No newline at end of file diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py index 175ee11e..40ff1fbe 100644 --- a/mindspeed_rl/workers/integrated_worker.py +++ b/mindspeed_rl/workers/integrated_worker.py @@ -119,7 +119,7 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB self.ref_manager.onload_param() end_onload_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/ref_onload", value=[round(end_onload_time, 4), round(start_onload_time, 4)], cumulate=True @@ -143,7 +143,7 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB self.ref_manager.offload_param() end_offload_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/ref_offload", value=[round(end_offload_time, 4), round(start_offload_time, 4)], cumulate=True diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py index 365d209d..e6efa218 100644 --- a/mindspeed_rl/workers/reference_woker.py +++ b/mindspeed_rl/workers/reference_woker.py @@ -9,7 +9,7 @@ import torch from mindspeed_rl.config_cls.megatron_config import MegatronConfig from mindspeed_rl.config_cls.rl_config import RLConfig -from mindspeed_rl.config_cls.generate_config import GenerateConfig +from mindspeed_rl.config_cls.generate_config import GenerateConfig from mindspeed_rl.models.reference import Reference from mindspeed_rl.utils.pad_process import truncate_rows from mindspeed_rl.utils.tokenizer import BaseTokenizer @@ -17,7 +17,7 @@ from mindspeed_rl.workers.base_worker import BaseWorker from mindspeed_rl.utils.compute import get_parallel_state from mindspeed_rl.trainer.utils.parallel_state import is_pipeline_last_stage, get_tensor_model_parallel_rank, get_context_parallel_rank from mindspeed_rl.utils.loggers import Loggers -from mindspeed_rl.utils.utils import mstx_timer_decorator, is_multimodal +from mindspeed_rl.utils.utils import mstx_timer_decorator logger = Loggers(__name__) @@ -70,7 +70,7 @@ class ReferenceWorkerBase(BaseWorker): else: self.megatron_config.iteration = 0 self.megatron_config.num_floating_point_operations_so_far = 0 - + megatron_module = self.get_megatron_module() self.reference = Reference( @@ -92,16 +92,16 @@ class ReferenceWorkerBase(BaseWorker): temperature=self.generate_config.sampling_config["temperature"] ) - def init_transfer_dock(self, td, mm_td): + def init_transfer_dock(self, td): self.td = td - self.mm_td = mm_td + + def init_metrics(self, metrics): + self.metrics = metrics @mstx_timer_decorator def compute_ref_log_prob(self): experience_consumer_stage = 'ref_log_prob' experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length'] - if is_multimodal(): - experience_columns.extend(['attention_mask', 'position_ids', 'input_ids_length']) experience_count = self.rl_config.ref_dispatch_size sorted_indexes = self.get_dp_range_indexes(experience_count, use_vllm=False) if self.rl_config.guarantee_order else None @@ -122,7 +122,7 @@ class ReferenceWorkerBase(BaseWorker): start_time = time.time() 
start_time_defined = True ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "start_time/reference_model", value=[round(start_time, 4)], cumulate=True @@ -140,7 +140,7 @@ class ReferenceWorkerBase(BaseWorker): self.collect_transfer_dock_data(output, index) end_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/reference_model", value=[round(end_time, 4), round(start_time, 4)], cumulate=True @@ -152,7 +152,7 @@ class ReferenceWorkerBase(BaseWorker): if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0 and self.parallel_state.get_context_parallel_rank() == 0: ref_end_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "end_time/reference", value=[round(ref_end_time, 4)] ) diff --git a/mindspeed_rl/workers/reward_woker.py b/mindspeed_rl/workers/reward_woker.py index 0b3ea2fc..83539699 100644 --- a/mindspeed_rl/workers/reward_woker.py +++ b/mindspeed_rl/workers/reward_woker.py @@ -86,6 +86,9 @@ class RewardWorkerBase(BaseWorker): def init_transfer_dock(self, td): self.td = td + def init_metrics(self, metrics): + self.metrics = metrics + def compute_rm_score(self): experience_consumer_stage = 'reward_scores' experience_columns = ['input_ids', 'prompt_length', "responses", "response_length", @@ -117,7 +120,7 @@ class RewardWorkerBase(BaseWorker): start_time = time.time() start_time_defined = True ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "start_time/reward_model", value=[round(start_time, 4)], cumulate=True @@ -137,7 +140,7 @@ class RewardWorkerBase(BaseWorker): self.collect_transfer_dock_data(output, index) end_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "timing/reward_model", value=[round(end_time, 4), round(start_time, 4)], cumulate=True @@ -148,7 +151,7 @@ class RewardWorkerBase(BaseWorker): if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0 and self.parallel_state.get_context_parallel_rank() == 0: rwd_end_time = time.time() ray.get( - self.td.update_metrics.remote( + self.metrics.update.remote( "end_time/reward_model", value=[round(rwd_end_time, 4)] ) diff --git a/mindspeed_rl/workers/rule_reward.py b/mindspeed_rl/workers/rule_reward.py index a68dba58..57e26a07 100644 --- a/mindspeed_rl/workers/rule_reward.py +++ b/mindspeed_rl/workers/rule_reward.py @@ -7,6 +7,7 @@ from mindspeed_rl.models.rule_verifier import compute_verifier_score, math_compu from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.trainer.utils.transfer_dock import pad_experience from mindspeed_rl.utils.utils import get_current_dp_range_indexes, is_multimodal +from mindspeed_rl.trainer.utils.transfer_dock import get_experience, put_experience logger = Loggers("rule_reward") @@ -26,9 +27,11 @@ class RuleReward(object): self.td = td self.mm_td = mm_td + def init_metrics(self, metrics): + self.metrics = metrics + def compute_rm_score(self): experience_consumer_stage = 'rule_reward' - experience_columns = ['prompts', 'responses', 'response_length', *self.megatron_config.dataset_additional_keys] experience_count = self.rl_config.reward_dispatch_size assign_batch_size = self.megatron_config.global_batch_size * self.rl_config.n_samples_per_prompt sorted_indexes = get_current_dp_range_indexes(experience_count=experience_count, @@ -36,14 +39,9 @@ class RuleReward(object): pad_token_id = self.tokenizer.pad if 
self.tokenizer.pad else self.tokenizer.eod while not ray.get(self.td.all_consumed.remote(experience_consumer_stage)): - batch_data, index = ray.get( - self.td.get_experience.remote( - experience_consumer_stage, - experience_columns, - experience_count, - indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None - ) - ) # cpu数据 + batch_data, index = get_experience(self.td, experience_consumer_stage, experience_count, + indexes=sorted_indexes.pop( + 0) if self.rl_config.guarantee_order else None) if batch_data and index: batch_data = pad_experience(batch_data, pad_token_id) # multiple, tp_size @@ -63,7 +61,7 @@ class RuleReward(object): self.hf_tokenizer, ignore_token) for key, value in metrics.items(): - ray.get(self.td.update_metrics.remote(key, value=value, cumulate=True)) + ray.get(self.metrics.update.remote(key, value=value, cumulate=True)) output = {"rm_scores": token_level_rewards, "token_level_rewards": token_level_rewards} else: @@ -90,4 +88,4 @@ class RuleReward(object): reward_tensor_normalized = (reward_tensor_reshaped - reward_mean) / reward_std reward_tensor = reward_tensor_normalized.reshape(original_shape) output = {"rm_scores": rm_scores, "token_level_rewards": reward_tensor} - self.td.put_experience.remote(data_dict=output, indexes=index) + put_experience(self.td, data_dict=output, indexes=index) diff --git a/mindspeed_rl/workers/scheduler/launcher.py b/mindspeed_rl/workers/scheduler/launcher.py index 9374090a..85519207 100644 --- a/mindspeed_rl/workers/scheduler/launcher.py +++ b/mindspeed_rl/workers/scheduler/launcher.py @@ -246,13 +246,13 @@ class RayActorGroup: def execute_sync_command(self, method_name: str, *args, **kwargs): return ray.get(self.execute_async_command(method_name, *args, **kwargs)) - def async_init_transfer_dock(self, transfer_dock, mm_transfer_dock=None): + def async_init_transfer_dock(self, transfer_dock): for actor in self.actor_handlers: - self.temp_actor_ref_objs.append(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock)) + self.temp_actor_ref_objs.append(actor.init_transfer_dock.remote(transfer_dock)) - def sync_init_transfer_dock(self, transfer_dock, mm_transfer_dock=None): + def sync_init_transfer_dock(self, transfer_dock): for actor in self.actor_handlers: - ray.get(actor.init_transfer_dock.remote(transfer_dock, mm_transfer_dock)) + ray.get(actor.init_transfer_dock.remote(transfer_dock)) def wait_all_ref_objs_run_over(self): ray.get(self.temp_actor_ref_objs) @@ -304,3 +304,10 @@ class RayActorGroup: def get_consumed_train_samples(self): return ray.get(self.actor_handlers[0].get_consumed_train_samples.remote()) + + def get_distributed_node_num(self): + return ray.get(self.actor_handlers[0].get_distributed_node_num.remote()) + + def sync_init_metrics(self, metrics): + for actor in self.actor_handlers: + ray.get(actor.init_metrics.remote(metrics)) -- Gitee
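
Note on the new metrics flow: this patch moves per-stage timing and training statistics out of the TransferDock (`td.update_metrics.remote`) and into a dedicated Ray metrics actor (`ActorMetric`, a `@ray.remote` subclass of `Metric`) that every worker updates remotely and the trainer later reduces with `metrics_post_processing`. Below is a minimal, self-contained sketch of that calling pattern. It assumes only that Ray is installed; the `MetricActorSketch` class is an illustrative stand-in that mirrors the `update()` / `metrics_post_processing()` call shapes seen in this diff and is not the project's `Metric`/`ActorMetric` implementation.

import ray


@ray.remote
class MetricActorSketch:
    """Illustrative per-trainer metrics actor, analogous to ActorMetric in this patch."""

    def __init__(self):
        self.metric = {}

    def update(self, key=None, value=None, cumulate=False):
        # Workers call this remotely, e.g.
        #   metrics.update.remote("timing/rollout", value=[t_end, t_start], cumulate=True)
        #   metrics.update.remote(key="grpo/lr", value=lr)
        #   metrics.update.remote(value=metrics_dict, cumulate=True)
        if isinstance(value, dict):
            # whole-dict update, as done after actor_train
            for k, v in value.items():
                self.update(k, v, cumulate)
            return
        if cumulate and isinstance(value, list):
            self.metric.setdefault(key, []).extend(value)
        else:
            self.metric[key] = value

    def metrics_post_processing(self):
        # Same reduction idea as the patched Metric.metrics_post_processing:
        # "timing/*" -> max - min, "start_time/*" -> min, "end_time/*" -> max, other lists -> mean.
        out = {}
        for key, value in self.metric.items():
            if not isinstance(value, list):
                out[key] = value
            elif "timing" in key:
                out[key] = max(value) - min(value)
            elif "start_time" in key:
                out[key] = min(value)
            elif "end_time" in key:
                out[key] = max(value)
            else:
                out[key] = sum(value) / len(value)
        return out


if __name__ == "__main__":
    ray.init()
    metrics = MetricActorSketch.remote()
    # worker side: cumulate timing pairs and scalar stats
    ray.get(metrics.update.remote("timing/rollout", value=[12.5, 10.0], cumulate=True))
    ray.get(metrics.update.remote(key="grpo/lr", value=1e-6))
    # trainer side: reduce once per iteration before logging
    print(ray.get(metrics.metrics_post_processing.remote()))

In this design the metrics actor is created once by the trainer and passed to every worker group through `sync_init_metrics`, so metric aggregation no longer depends on which TransferDock (single or distributed) a worker happens to consume from.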