# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import time
import copy
from enum import Enum
from typing import Callable
import logging as logger
import numpy as np
import ray
import torch
from transformers import AutoConfig
from mindspeed_rl.config_cls.megatron_config import MegatronConfig
from mindspeed_rl.config_cls.rl_config import RLConfig
from mindspeed_rl.config_cls.generate_config import GenerateConfig
from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig
from mindspeed_rl.models.actor_rollout_hybrid import ActorRolloutHybrid
from mindspeed_rl.models.rollout.vllm_engine import VLLMInferEngine
from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.utils.utils import MsProbe
from mindspeed_rl.workers.base_worker import BaseWorker
from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronShardingManager
from mindspeed_rl.workers.resharding.megatron_off_loader import MegatronOffLoader
from mindspeed_rl.utils.utils import (num_floating_point_operations, get_attr_wrapped_model, mstx_timer_decorator,
profiler_start, profiler_step, is_multimodal, replace_torch_compile)
from mindspeed_rl.utils.pad_process import remove_padding_and_split_to_list, truncate_rows
from mindspeed_rl.utils.zmq_communication import (ZmqServer, ZmqClient, ZmqServerInfo, ZmqClientInfo,
ZMQ_ROLE_SERVER, ZMQ_ROLE_CLIENT)
from mindspeed_rl.models.rollout.vllm_adapter.vllm_parallel_state import get_vllm_tp_group_ranks
from mindspeed_rl.trainer.utils.mm_transfer_dock import unpack_mm_experience
class ActorState(Enum):
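    """Coarse runtime state of the hybrid worker: NONE (training / idle) vs. INFER (weights resharded for rollout)."""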
NONE = "none"
INFER = "infer"
class ActorHybridWorkerBase(BaseWorker):
"""
ActorHybridWorker class. This class implements the hybrid worker logic for training and inference.
Args:
megatron_config: MegatronConfig Configuration for Megatron-LM (e.g., model parallelism settings).
rl_config: RLConfig Configuration for reinforcement learning (e.g., PPO settings).
generate_config: GenerateConfig Configuration for generation/inference (e.g., vLLM settings).
model_provider: Callable Function to provide the model instance.
initialize_func: Callable Function to initialize the model and environment.
        tokenizer: BaseTokenizer = None Tokenizer wrapper used for encoding and decoding.
        get_megatron_module: Callable = None Function that returns the required Megatron modules.
        profiler_config: ProfilerConfig Configuration for profiling.
        msprobe_config: MsprobeConfig Configuration for msprobe.
**kwargs: Additional parameters for base class argument passing.
"""
def __init__(
self,
megatron_config: MegatronConfig,
rl_config: RLConfig,
generate_config: GenerateConfig,
model_provider: Callable,
initialize_func: Callable,
tokenizer: BaseTokenizer = None,
get_megatron_module: Callable = None,
profiler_config: ProfilerConfig = None,
msprobe_config: MsprobeConfig = None,
**kwargs
):
super().__init__(
megatron_config,
rl_config,
generate_config,
model_provider=model_provider,
initialize_func=initialize_func,
tokenizer=tokenizer,
get_megatron_module=get_megatron_module,
profiler_config=profiler_config,
msprobe_config=msprobe_config,
**kwargs
)
self.num_floating_point_operations_so_far = 0
self.actor_hybrid = None
self.actor_offloader = None
self.state = ActorState.NONE
self.actor_profiler = None
self.prof_iteration = 1
self.idx = 0
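        # Partial rollout splits long generations over several rounds; it is enabled
        # whenever the configured number of splits is greater than one.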
self.enable_partial_rollout = self.rl_config.partial_rollout_max_split > 1
def initialize(self):
self.setup_distributed_rank()
self.model, self.optimizer, self.opt_param_scheduler = self._build_model_optimizer()
self._set_no_sync_func()
self.actor_offloader = MegatronOffLoader(
self.model,
self.optimizer,
megatron_config=self.megatron_config,
distributed_optimizer=self.distributed_optimizer,
float16_optimizer_with_float16_params=self.float16_optimizer_with_float16_params)
if self.generate_config.offload_train_optimizer:
self.actor_offloader.offload_optimizer()
if self.generate_config.offload_train_grad:
self.actor_offloader.offload_grad()
if self.generate_config.offload_train_param:
self.actor_offloader.offload_param()
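        # Free device memory (optimizer state, grads, params) before the vLLM engine
        # is built below, which allocates its own weights and KV cache; train params
        # are onloaded again once the sharding manager has been created.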
with replace_torch_compile():
self.inference_model = self._build_rollout()
self.sharding_manager = self._build_sharding_manager()
if self.generate_config.offload_train_param:
self.actor_offloader.onload_param()
self.actor_hybrid = ActorRolloutHybrid(
self.model,
megatron_config=self.megatron_config,
optimizer=self.optimizer,
opt_param_scheduler=self.opt_param_scheduler,
inference_model=self.inference_model,
sharding_manager=self.sharding_manager,
beta=self.rl_config.beta,
mini_batch_size_per_dp=self.rl_config.mini_batch_size
// self.parallel_state.get_data_parallel_world_size(),
epochs=self.rl_config.epochs,
shuffle_mini_batch=self.rl_config.shuffle_mini_batch,
generate_config=self.generate_config,
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
clip_ratio=self.rl_config.clip_ratio,
micro_batch_size=self.megatron_config.micro_batch_size,
entropy_coeff=self.rl_config.entropy_coeff,
kl_penalty=self.rl_config.kl_penalty,
temperature=self.generate_config.sampling_config["temperature"],
token_level_loss=self.rl_config.token_level_loss,
clip_higher_enable=self.rl_config.clip_higher_enable,
clip_ratio_low=self.rl_config.clip_ratio_low,
clip_ratio_high=self.rl_config.clip_ratio_high,
reuse_image_embeds=self.rl_config.reuse_image_embeds,
use_remove_padding=self.rl_config.use_remove_padding,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
actor_max_packing_token_size=self.rl_config.actor_max_packing_token_size,
update_max_packing_token_size=self.rl_config.update_max_packing_token_size,
actor_dynamic_max_batch_size=self.rl_config.actor_dynamic_max_batch_size,
update_dynamic_max_batch_size=self.rl_config.update_dynamic_max_batch_size,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size
)
self.empty_cache()
self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
MsProbe.config_init(self.msprobe_config)
def init_transfer_dock(self, td, mm_td=None, sampling_transfer_dock=None, mm_sampling_transfer_dock=None):
self.td = td
self.mm_td = mm_td
self.sampling_transfer_dock = sampling_transfer_dock
self.mm_sampling_transfer_dock = mm_sampling_transfer_dock
self.empty_cache()
def get_iteration(self):
return self.args.iteration
def get_consumed_train_samples(self):
return self.args.consumed_train_samples
def enter_infer_mode(self):
if self.state == ActorState.INFER:
return
start_time = time.time()
self.sharding_manager.enter_infer_mode()
self.state = ActorState.INFER
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/resharding_enter_infer",
value=[end_time - start_time],
cumulate=True
)
)
def exit_infer_mode(self):
if self.state != ActorState.INFER:
            raise RuntimeError(f"exit_infer_mode called in state {self.state}, expected ActorState.INFER")
start_time = time.time()
self.sharding_manager.exit_infer_mode()
self.state = ActorState.NONE
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/resharding_exit_infer",
value=[end_time - start_time],
cumulate=True
)
)
@mstx_timer_decorator
def update(self, kl_ctrl=None, skip_actor_log_prob=False):
start_sharding_enter_train = time.time()
self.sharding_manager.enter_train_mode()
sharding_train_interval = time.time() - start_sharding_enter_train
self.args.curr_iteration = self.iteration
experience_consumer_stage = 'actor_train'
if self.megatron_config.stage == "ray_dapo":
experience_columns = ['responses', 'advantages', 'old_log_prob', 'input_ids', 'response_length', 'prompt_length']
else:
experience_columns = ['responses', 'advantages', 'old_log_prob', 'ref_log_prob', 'input_ids', 'response_length', 'prompt_length']
if is_multimodal():
experience_columns.extend(['attention_mask', 'position_ids'])
experience_count = self.rl_config.actor_update_dispatch_size
else:
experience_count = (
self.megatron_config.global_batch_size // self.parallel_state.get_data_parallel_world_size()
)
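            # filter_groups (dynamic sampling) overrides the default per-DP-rank slice:
            # the update consumes filter_groups_train_batch_size prompts, each with
            # n_samples_per_prompt rollouts, split evenly across data-parallel ranks.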
if self.rl_config.filter_groups_enable:
experience_count = (
self.rl_config.filter_groups_train_batch_size * self.rl_config.n_samples_per_prompt //
self.parallel_state.get_data_parallel_world_size()
)
if skip_actor_log_prob:
experience_columns.remove('old_log_prob')
        # get the current learning rate from the optimizer
learning_rate = None
for param_group in self.optimizer.param_groups:
learning_rate = param_group['lr']
ray.get(self.td.update_metrics.remote(key='actor/lr', value=learning_rate))
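        # guarantee_order makes every DP rank consume a fixed, pre-sorted index range
        # from the transfer dock instead of taking experience in arrival order.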
sorted_indexes = self.get_dp_range_indexes(
experience_count,
use_vllm=False
) if self.rl_config.guarantee_order else None
actor_update_profiler = profiler_start(
self.profiler_config,
role="actor_update",
profiler_iteration=self.prof_iteration
)
MsProbe.debugger_start(self.model[0], tag='actor_update')
start_time_defined = False
while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
experience_columns,
experience_count,
self.megatron_config.tensor_model_parallel_size,
self.megatron_config.context_parallel_size,
self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=self.enable_partial_rollout)
if not start_time_defined:
start_time = time.time()
start_time_defined = True
if batch_data and index:
metrics = self.actor_hybrid.update_actor(batch_data, kl_ctrl)
self.args.consumed_train_samples += self.megatron_config.global_batch_size // self.rl_config.n_samples_per_prompt
self.num_floating_point_operations_so_far += num_floating_point_operations(self.args,
self.megatron_config.global_batch_size)
                if (self.parallel_state.is_pipeline_last_stage(ignore_virtual=True)
                        and self.parallel_state.get_tensor_model_parallel_rank() == 0
                        and self.parallel_state.get_context_parallel_rank() == 0):
ray.get(self.td.update_metrics.remote(value=metrics, cumulate=True))
ray.get(
self.td.update_metrics.remote(
"timing/update",
value=[round(time.time(), 4), round(start_time, 4)],
cumulate=True
)
)
self.iteration += 1
profiler_step(actor_update_profiler)
MsProbe.debugger_stop(tag='actor_update')
MsProbe.step()
self.prof_iteration += 1
start_sharding_exit_train = time.time()
self.sharding_manager.exit_train_mode()
sharding_train_interval += (time.time() - start_sharding_exit_train)
ray.get(
self.td.update_metrics.remote(
"timing/resharding_to_train",
value=[sharding_train_interval],
cumulate=True
)
)
profiler_step(self.actor_profiler)
logger.info("finish actor update")
def save_ckpt(self, iteration: int):
self.sharding_manager.enter_train_mode()
self.save_checkpoint(iteration, self.model, self.optimizer, self.opt_param_scheduler,
self.num_floating_point_operations_so_far)
self.sharding_manager.exit_train_mode()
self.empty_cache()
def get_partial_rollout_stop_signal(self):
if not self.enable_partial_rollout:
return False
td = self.sampling_transfer_dock if self.sampling_transfer_dock else self.td
return ray.get(
td.get_update_ready.remote(require_max_age_all_finished=self.rl_config.require_max_age_all_finished))
@mstx_timer_decorator
def generate_sequences(self):
sharding_infer_interval = 0
if not self.rl_config.filter_groups_enable:
start_sharding_enter_infer = time.time()
self.sharding_manager.enter_infer_mode()
sharding_infer_interval = time.time() - start_sharding_enter_infer
experience_consumer_stage = 'actor_rollout'
experience_columns = ['prompts', 'prompt_length']
if is_multimodal():
experience_columns.extend(['input_ids', 'input_ids_length'])
if self.enable_partial_rollout:
experience_columns.extend(['responses', 'response_length', 'age'])
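        # On the last train iteration, or when the async engine is used, size the
        # dispatch so that each inference DP rank drains an equal share of the
        # still-incomplete responses; otherwise use the configured dispatch size.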
if self.enable_partial_rollout and (self.rl_config.async_engine or self.iteration == self.megatron_config.train_iters - 1):
td = self.sampling_transfer_dock if self.sampling_transfer_dock else self.td
incomplete_resp_num = ray.get(td.get_incomplete_response_num.remote())
experience_count = int(np.ceil(incomplete_resp_num / self.generate_config.data_parallel_size))
else:
experience_count = self.rl_config.actor_rollout_dispatch_size
pad_token_id = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod
sorted_indexes = self.get_dp_range_indexes(experience_count,
use_vllm=True) if self.rl_config.guarantee_order else None
actor_generate_profiler = profiler_start(self.profiler_config, role="actor_generate",
profiler_iteration=self.prof_iteration)
MsProbe.debugger_start(self.inference_model.model, tag='actor_generate_sequences')
start_time = time.time()
while self.all_consumed(experience_consumer_stage, sorted_indexes, use_vllm=True, is_generate=True) > 0:
batch_data, index = self.dispatch_transfer_dock_data(
experience_consumer_stage,
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
cp_size=self.megatron_config.context_parallel_size,
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
use_vllm=True,
get_n_samples=not self.enable_partial_rollout,
enable_partial_rollout=self.enable_partial_rollout,
is_generate=True
)
if batch_data and index:
if self.rl_config.async_engine:
logger.info(f"do async generate process.")
self.async_generate_process(batch_data, index, pad_token_id)
else:
self.sync_generate_process(batch_data, experience_count, index, pad_token_id)
if self.enable_partial_rollout:
torch.distributed.barrier()
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/rollout",
value=[round(end_time, 4), round(start_time, 4)],
cumulate=True
)
)
profiler_step(actor_generate_profiler)
MsProbe.debugger_stop('actor_generate_sequences')
self.idx += 1
if not self.rl_config.filter_groups_enable:
start_sharding_exit_infer = time.time()
self.sharding_manager.exit_infer_mode()
torch.cuda.empty_cache()
sharding_infer_interval += (time.time() - start_sharding_exit_infer)
ray.get(
self.td.update_metrics.remote(
"timing/resharding_to_infer",
value=[sharding_infer_interval],
cumulate=True
)
)
logger.info("finish generate_sequences")
def sync_generate_process(self, batch_data, experience_count, index, pad_token_id):
if not self.enable_partial_rollout:
indexes = list(range(0, experience_count, self.rl_config.n_samples_per_prompt))
if self.rl_config.reuse_image_embeds:
prompts_data = batch_data['input_ids'][indexes]
prompt_length_data = batch_data['input_ids_length'][indexes]
# preprocess, remove padding
prompts = truncate_rows(prompts_data, prompt_length_data, left_pad=True)
else:
prompts_data = batch_data['prompts'][indexes]
prompt_length_data = batch_data['prompt_length'][indexes]
# preprocess, remove padding
prompts = truncate_rows(prompts_data, prompt_length_data)
prompts_list = [prompt.numpy().tolist() for prompt in prompts]
else:
prompts_data = batch_data['prompts']
prompt_length_data = batch_data['prompt_length']
responses = batch_data['responses']
responses_length_partial = batch_data['response_length']
responses_partial = truncate_rows(responses, responses_length_partial)
prompts = truncate_rows(prompts_data, prompt_length_data)
prompts_for_vllm = [torch.cat(
(prompt, response), dim=0) for prompt, response in
zip(prompts, responses_partial)]
prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm]
if self.enable_partial_rollout:
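            # Each rollout round generates at most 1/partial_rollout_max_split of the
            # full max_tokens budget; long responses finish over multiple rounds.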
max_tokens = self.generate_config.sampling_config["max_tokens"] // self.rl_config.partial_rollout_max_split
responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list),
max_tokens=max_tokens, n=1,
extra_info=batch_data)
else:
responses_pad_right = self.actor_hybrid.generate_sequences(copy.deepcopy(prompts_list),
extra_info=batch_data)
responses = remove_padding_and_split_to_list(responses_pad_right, self.tokenizer.eod, pad_token_id)
if is_multimodal():
prompts_data = batch_data['input_ids'][indexes].cpu().unbind()
else:
prompts_data = prompts
if self.enable_partial_rollout:
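            # Stitch the previously generated partial response onto the new chunk, then
            # clamp any token id at or above the vocab size (presumably a padding or
            # sentinel value picked up in transit) to 0 so detokenization cannot fail.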
new_responses = []
for response_partial, response in zip(responses_partial, responses):
new_resp = torch.cat((response_partial, response), dim=0)
test_resp = new_resp >= self.tokenizer.vocab_size
if test_resp.sum() > 0:
new_resp[test_resp] = 0
new_responses.append(new_resp)
responses = new_responses
else:
prompts = []
for prompt in prompts_data:
for _ in range(self.rl_config.n_samples_per_prompt):
prompts.append(copy.deepcopy(prompt))
responses_length = [torch.tensor([len(response)]) for response in responses]
input_ids_list = []
for prompt, response in zip(prompts, responses):
input_ids_list.append(torch.cat((prompt, response), dim=0))
outputs = {
'responses': responses,
'input_ids': input_ids_list,
'response_length': responses_length
}
if is_multimodal():
outputs['prompt_length'] = batch_data['input_ids_length']
if self.enable_partial_rollout:
finish_status = [torch.tensor([0])] * len(responses_length)
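            # A rollout is complete when it ends with EOS, or when prompt + response
            # reaches max_model_len, or the response hits the max_tokens sampling cap.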
            for idx, _ in enumerate(responses):
                if (responses[idx][-1] == self.tokenizer.eod
                        or prompt_length_data[idx][0] + responses_length[idx][0] >= self.generate_config.max_model_len
                        or responses_length[idx][0] >= self.generate_config.sampling_config["max_tokens"]):
                    finish_status[idx] = torch.tensor([1])
outputs["rollout_completed"] = finish_status
self.collect_transfer_dock_data(outputs, index, use_vllm=True, is_generate=True, sync=is_multimodal())
MsProbe.save_data({"responses": responses, "prompts": prompts})
def async_generate_process(self, batch_data, index, pad_token_id):
self.actor_hybrid.inference_actor.init_cache_engine()
prompts_data = batch_data['prompts']
prompt_length_data = batch_data['prompt_length']
prompts = truncate_rows(prompts_data, prompt_length_data)
if self.enable_partial_rollout:
responses = batch_data['responses']
responses_length_partial = batch_data['response_length']
responses_partial = truncate_rows(responses, responses_length_partial)
prompts_for_vllm = [torch.cat((prompt, response), dim=0) for prompt, response in zip(prompts, responses_partial)]
prompts_list = [prompt.numpy().tolist() for prompt in prompts_for_vllm]
else:
prompts_list = [prompt.numpy().tolist() for prompt in prompts]
if self.enable_partial_rollout:
response_generator = self.actor_hybrid.generate_sequences(
copy.deepcopy(prompts_list),
indexes=index,
n=1,
async_engine=True,
stop_singal_func=self.get_partial_rollout_stop_signal,
)
else:
response_generator = self.actor_hybrid.generate_sequences(
copy.deepcopy(prompts_list),
indexes=index,
n=1,
async_engine=True,
)
for samples, idx_output in response_generator:
prompts, responses, log_probs = samples
responses = remove_padding_and_split_to_list(responses, self.tokenizer.eod, pad_token_id)
remove_input_ids = False
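            # A length-1 first response appears to be the stub emitted when generation
            # is aborted by the partial-rollout stop signal; its concatenated input_ids
            # are meaningless, so they are dropped from the outputs below.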
if self.enable_partial_rollout and len(responses[0]) == 1:
remove_input_ids = True
if self.enable_partial_rollout:
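                # The async engine yields samples in completion order, not dispatch
                # order: map each yielded global index (idx_output) back to its slot in
                # the dispatched batch to recover the matching partial response and
                # prompt length.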
responses_partial_new = []
prompt_length_new = []
for idx in range(len(responses)):
iidx = index.index(idx_output[idx])
responses_partial_new.append(responses_partial[iidx])
prompt_length_new.append(prompt_length_data[iidx])
new_responses = []
for response_partial, response in zip(responses_partial_new, responses):
new_resp = torch.cat((response_partial, response), dim=0)
test_resp = new_resp >= self.tokenizer.vocab_size
if test_resp.sum() > 0:
new_resp[test_resp] = 0
new_responses.append(new_resp)
responses = new_responses
responses_length = [torch.tensor([len(response)]) for response in responses]
input_ids_list = []
for prompt, response in zip(prompts, responses):
input_ids_list.append(torch.cat((prompt, response), dim=0))
outputs = {
'responses': responses,
'input_ids': input_ids_list,
'response_length': responses_length
}
if remove_input_ids:
outputs.pop("input_ids")
if self.enable_partial_rollout:
finish_status = [torch.tensor([0])] * len(responses_length)
                for idx, _ in enumerate(responses):
                    if (responses[idx][-1] == self.tokenizer.eod
                            or prompt_length_new[idx][0].to('cpu') + responses_length[idx][0] >= self.generate_config.max_model_len
                            or responses_length[idx][0] >= self.generate_config.sampling_config["max_tokens"]):
                        finish_status[idx] = torch.tensor([1])
outputs["rollout_completed"] = finish_status
self.collect_transfer_dock_data(outputs, idx_output, use_vllm=True, is_generate=True, sync=is_multimodal())
MsProbe.save_data({"responses": responses, "prompts": prompts})
self.actor_hybrid.inference_actor.free_cache_engine()
@mstx_timer_decorator
def compute_log_prob(self):
self.sharding_manager.enter_forward_mode()
experience_consumer_stage = 'actor_log_prob'
experience_columns = ['input_ids', 'responses', 'response_length', 'prompt_length']
if is_multimodal():
experience_columns.extend(['attention_mask', 'position_ids', 'input_ids_length'])
experience_count = self.rl_config.actor_logprob_dispatch_size
sorted_indexes = self.get_dp_range_indexes(
experience_count,
use_vllm=False,
) if self.rl_config.guarantee_order else None
actor_compute_log_prob_profiler = profiler_start(
self.profiler_config,
role="actor_compute_log_prob",
profiler_iteration=self.prof_iteration
)
MsProbe.debugger_start(self.model[0], tag='actor_compute_log_prob')
start_time_defined = False
while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
cp_size=self.megatron_config.context_parallel_size,
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(
0) if self.rl_config.guarantee_order else None,
get_n_samples=self.enable_partial_rollout)
if not start_time_defined:
start_time = time.time()
start_time_defined = True
if batch_data and index:
output, batch = self.actor_hybrid.compute_log_prob(batch_data)
if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True):
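                    # Outputs only materialize on the last pipeline stage: concatenate
                    # the micro-batch log-probs, upcast to fp32, and trim each row to
                    # its true response length before writing back to the transfer dock.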
log_probs = torch.cat(output, dim=0)
log_probs = log_probs.to(torch.float32)
log_probs = truncate_rows(log_probs, batch['response_length'])
output = {'old_log_prob': log_probs}
self.collect_transfer_dock_data(output, index)
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/old_log_p",
value=[round(end_time, 4), round(start_time, 4)],
cumulate=True
)
)
ray.get(
self.td.update_metrics.remote(
"end_time/old_log_p",
value=[round(end_time, 4)],
cumulate=True
)
)
self.sharding_manager.exit_forward_mode()
torch.cuda.empty_cache()
profiler_step(actor_compute_log_prob_profiler)
MsProbe.debugger_stop('actor_compute_log_prob')
logger.info("finish compute_log_prob")
@mstx_timer_decorator
def compute_image_embeds(self):
experience_consumer_stage = 'actor_image_embeds'
experience_columns = ['input_ids']
experience_count = self.rl_config.actor_image_embeds_dispatch_size
sorted_indexes = self.get_dp_range_indexes(experience_count,
use_vllm=False) if self.rl_config.guarantee_order else None
actor_image_embeds_profiler = profiler_start(self.profiler_config, role="actor_image_embeds",
profiler_iteration=self.prof_iteration)
MsProbe.debugger_start(self.model[0], tag='actor_image_embeds')
start_time_defined = False
while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
batch_data, index = self.dispatch_transfer_dock_data(
experience_consumer_stage,
experience_columns,
experience_count,
tp_size=self.megatron_config.tensor_model_parallel_size,
cp_size=self.megatron_config.context_parallel_size,
cp_algo=self.megatron_config.context_parallel_algo,
indexes=sorted_indexes.pop(0) if self.rl_config.guarantee_order else None,
get_n_samples=True
)
if not start_time_defined:
start_time = time.time()
start_time_defined = True
if batch_data and index:
indexes = list(range(0, experience_count, self.rl_config.n_samples_per_prompt))
batch_data['input_ids'] = batch_data['input_ids'][indexes]
output, batch = self.actor_hybrid.compute_image_embeds(batch_data)
if self.parallel_state.is_pipeline_last_stage(ignore_virtual=True):
output = torch.cat(output, dim=0)
data = {
"vit_embeds": output.squeeze(1).cpu(),
"image_grid_thw": batch_data['image_grid_thw'],
"image_num": batch_data['image_num'],
"video_num": batch_data['video_num']
}
data = unpack_mm_experience(data)
output = {'vit_embeds': data['vit_embeds']}
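                    # All n_samples_per_prompt rollouts of a prompt share one image
                    # embed, so keep a single index per prompt group for the MM dock.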
index = [i // self.rl_config.n_samples_per_prompt for i in index[::self.rl_config.n_samples_per_prompt]]
self.collect_transfer_dock_mm_data(output, index)
end_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/image_embeds",
value=[round(end_time, 4), round(start_time, 4)],
cumulate=True
)
)
ray.get(
self.td.update_metrics.remote(
"end_time/image_embeds",
value=[round(end_time, 4)],
cumulate=True
)
)
profiler_step(actor_image_embeds_profiler)
MsProbe.debugger_stop('actor_image_embeds')
logger.info("finish compute_image_embeds")
def _build_model_optimizer(self):
actor_module, optimizer, opt_param_scheduler = self.setup_model_and_optimizer(
self.model_provider, self.model_type.encoder_or_decoder)
self.iteration = self.get_iteration()
return actor_module, optimizer, opt_param_scheduler
def _build_rollout(self):
self.actor_model_config = AutoConfig.from_pretrained(
self.megatron_config.tokenizer_name_or_path, trust_remote_code=self.generate_config.trust_remote_code)
sampling_config = {"num_completions": self.rl_config.n_samples_per_prompt,
**self.generate_config.sampling_config}
rollout = VLLMInferEngine(
tokenizer_name_or_path=self.megatron_config.tokenizer_name_or_path,
train_tensor_parallel_size=self.megatron_config.tensor_model_parallel_size,
train_pipeline_parallel_size=self.megatron_config.pipeline_model_parallel_size,
train_expert_parallel_size=self.megatron_config.expert_model_parallel_size,
train_context_parallel_size=self.megatron_config.context_parallel_size,
infer_tensor_parallel_size=self.generate_config.infer_tensor_parallel_size,
infer_pipeline_parallel_size=self.generate_config.infer_pipeline_parallel_size,
infer_expert_parallel_size=self.generate_config.infer_expert_parallel_size,
megatron_config=self.megatron_config,
sampling_config=sampling_config,
enable_prefix_caching=self.generate_config.enable_prefix_caching,
num_scheduler_steps=self.generate_config.num_scheduler_steps,
max_num_seqs=self.generate_config.max_num_seqs,
max_model_len=self.generate_config.max_model_len,
dtype=self.generate_config.dtype,
gpu_memory_utilization=self.generate_config.gpu_memory_utilization,
trust_remote_code=self.generate_config.trust_remote_code,
enforce_eager=self.generate_config.enforce_eager,
torchair_graph=self.generate_config.torchair_graph,
ascend_scheduler_config_enabled=self.generate_config.ascend_scheduler_config_enabled,
enable_expert_parallel=self.generate_config.enable_expert_parallel,
expert_map_path=self.generate_config.expert_map_path,
eplb_token_collects=self.generate_config.eplb_token_collects,
eplb_token_save_path=self.generate_config.eplb_token_save_path,
max_num_batched_tokens=self.generate_config.max_num_batched_tokens,
limit_mm_image_per_prompt=self.generate_config.limit_mm_image_per_prompt,
limit_mm_video_per_prompt=self.generate_config.limit_mm_video_per_prompt
)
if self.rl_config.zmq_communication:
            # without this call, the tp and pp ranks queried below would both be 0
torch.distributed.get_rank()
vllm_dp_groups = get_vllm_tp_group_ranks()
if vllm_dp_groups is None:
raise ValueError("vllm dp groups is None")
from vllm.distributed import parallel_state as vpu
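            # Elect exactly one ZMQ server per vLLM model-parallel group (the rank with
            # tp rank 0 and pp rank 0); all other ranks in the group attach as clients.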
if vpu.get_tensor_model_parallel_rank() == 0 and \
vpu.get_pipeline_model_parallel_group().rank_in_group == 0:
server_info = ZmqServerInfo()
server_info.global_rank = self._rank
server_info.dp_world_size = (vpu.get_tensor_model_parallel_world_size() *
vpu.get_pp_group().world_size)
server_info.ip_addr = ray._private.services.get_node_ip_address().strip("[]")
server_info.register_port = self._get_free_port()
server_info.publisher_port = self._get_free_port()
server_info.reliability_port = self._get_free_port()
server_info.use_vllm = True
self.zmq_role = ZMQ_ROLE_SERVER
self.zmq_server_vllm = ZmqServer(server_info, vpu)
else:
client_info = ZmqClientInfo()
client_info.global_rank = self._rank
client_info.use_vllm = True
self.zmq_role = ZMQ_ROLE_CLIENT
self.zmq_client_vllm = ZmqClient(client_info, vpu)
return rollout
def _build_sharding_manager(self):
# perform weight resharding between actor and rollout
sharding_manager = MegatronShardingManager(
megatron_model=self.model,
model_config=self.actor_model_config,
infer_tensor_parallel_size=self.generate_config.infer_tensor_parallel_size,
infer_pipeline_parallel_size=self.generate_config.infer_pipeline_parallel_size,
infer_expert_parallel_size=self.generate_config.infer_expert_parallel_size,
num_layer_list=self.megatron_config.num_layer_list,
moe_tp_extend_ep=self.megatron_config.moe_tp_extend_ep,
parallel_state=self.parallel_state,
inference_engine=self.inference_model,
optimizer=self.optimizer,
optimizer_offload=self.generate_config.offload_train_optimizer,
grad_offload=self.generate_config.offload_train_grad,
train_param_offload=self.generate_config.offload_train_param,
megatron_offloader=self.actor_offloader,
noop_layers=self.megatron_config.noop_layers
)
return sharding_manager
def _set_no_sync_func(self):
config = get_attr_wrapped_model(self.model[0], 'config', allow_none=False)
config.grad_scale_func = self.optimizer.scale_loss
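        # With overlap_grad_reduce, gradient all-reduce is deferred through each DDP
        # model chunk's no_sync context and finalized once by finalize_model_grads.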
if isinstance(self.model[0], self.distributed_data_parallel) and self.megatron_config.overlap_grad_reduce:
if config.no_sync_func is not None:
raise ValueError('When overlap_grad_reduce is True, config.no_sync_func must be None; '
'a custom no_sync_func is not supported when overlapping grad-reduce')
config.no_sync_func = [model_chunk.no_sync for model_chunk in self.model]
if len(self.model) == 1:
config.no_sync_func = config.no_sync_func[0]
config.finalize_model_grads_func = self.finalize_model_grads
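# Requesting a fractional NPU lets this actor share a device with other colocated
# workers (e.g., a reference or reward worker) in the hybrid deployment.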
@ray.remote(resources={"NPU": 0.7})
class ActorHybridWorker(ActorHybridWorkerBase):
pass
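# A minimal wiring sketch (hypothetical driver code; in MindSpeed-RL the trainer's
# actor group normally constructs these workers and the transfer dock):
#
#   worker = ActorHybridWorker.remote(
#       megatron_config, rl_config, generate_config,
#       model_provider=model_provider, initialize_func=initialize_func,
#       tokenizer=tokenizer, get_megatron_module=get_megatron_module)
#   ray.get(worker.initialize.remote())
#   ray.get(worker.init_transfer_dock.remote(td))
#   ray.get(worker.generate_sequences.remote())   # rollout
#   ray.get(worker.compute_log_prob.remote())     # old log-probs
#   ray.get(worker.update.remote(kl_ctrl))        # policy update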