# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
import dataclasses
from typing import Callable
from contextlib import contextmanager
import time
import ray
import torch
from mindspeed_rl.config_cls.megatron_config import MegatronConfig
from mindspeed_rl.config_cls.rl_config import RLConfig
from mindspeed_rl.config_cls.generate_config import GenerateConfig
from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig
from mindspeed_rl.utils.tokenizer import BaseTokenizer
from mindspeed_rl.workers.resharding.megatron_off_loader import MegatronOffLoader
from mindspeed_rl.utils.utils import mstx_timer_decorator, profiler_start, profiler_step
from mindspeed_rl.utils.utils import MsProbe
from mindspeed_rl.workers.actor_hybrid_worker import ActorHybridWorkerBase
from mindspeed_rl.workers.reference_woker import ReferenceWorkerBase
from mindspeed_rl.workers.reward_woker import RewardWorkerBase
from mindspeed_rl.models.reference import Reference
from mindspeed_rl.models.reward import Reward
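# Reserve a fractional share of the custom Ray "NPU" resource for each IntegratedWorker
# (the 0.7 request leaves scheduling headroom on the device for co-located processes).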
@ray.remote(resources={"NPU": 0.7})
class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerBase):
"""
    IntegratedWorker class. This class implements an integrated worker that combines the Actor, Reference, and Reward workers.
Args:
megatron_config: MegatronConfig Configuration for Megatron-LM (e.g., model parallelism settings).
rl_config: RLConfig Configuration for reinforcement learning (e.g., PPO settings).
generate_config: GenerateConfig Configuration for generation/inference (e.g., vLLM settings).
model_provider: Callable Function to provide the model instance.
initialize_func: Callable Function to initialize the model and environment.
tokenizer: BaseTokenizer = None Object to retrieve the tokenizer.
        get_megatron_module: Callable = None Function that returns the required Megatron modules.
        profiler_config: ProfilerConfig Configuration for profiling.
        msprobe_config: MsprobeConfig Configuration for msprobe.
**kwargs: Additional parameters for base class argument passing.
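
    A minimal construction sketch (hypothetical driver code; in practice the trainer
    creates and initializes these workers through its own Ray utilities):

        worker = IntegratedWorker.remote(
            megatron_config, rl_config, generate_config,
            model_provider=model_provider,
            initialize_func=initialize_func,
            tokenizer=tokenizer,
        )
        ray.get(worker.initialize.remote())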
"""
def __init__(
self,
megatron_config: MegatronConfig,
rl_config: RLConfig,
generate_config: GenerateConfig,
model_provider: Callable,
initialize_func: Callable,
tokenizer: BaseTokenizer = None,
get_megatron_module: Callable = None,
profiler_config: ProfilerConfig = None,
msprobe_config: MsprobeConfig = None,
**kwargs
):
        # The Actor acts as the main worker, so only the Actor base class is initialized here.
ActorHybridWorkerBase.__init__(
self,
megatron_config,
rl_config,
generate_config,
model_provider=model_provider,
initialize_func=initialize_func,
tokenizer=tokenizer,
get_megatron_module=get_megatron_module,
profiler_config=profiler_config,
msprobe_config=msprobe_config,
**kwargs
)
self.actor_forward_micro_batch_size = rl_config.actor_forward_micro_batch_size
self.ref_forward_micro_batch_size = rl_config.ref_forward_micro_batch_size
self.vit_forward_micro_batch_size = rl_config.vit_forward_micro_batch_size
self.reference = None
self.ref_model = None
self.ref_manager = None
def initialize(self):
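        """Initialize the integrated worker.

        First run the Actor initialization, then build the reference model: load its
        weights (from ref_model_load_path if configured, otherwise from the default
        Megatron load path), offload its parameters, and wrap it in a Reference
        instance for log-prob computation.
        """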
# Based on Actor
ActorHybridWorkerBase.initialize(self)
# Add Reference
self.ref_model = self.get_model(self.model_provider, self.model_type, wrap_with_ddp=False)
ref_model_load_path = getattr(
self.rl_config.integrated_mode_config, "ref_model_load_path", None
) if self.rl_config.integrated_mode_config is not None else None
self.load_checkpoint_with_path(self.ref_model, ref_model_load_path, ckpt_only=True)
self.ref_manager = MegatronOffLoader(self.ref_model, wrap_with_ddp=False)
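        # Keep the reference parameters offloaded until compute_ref_log_prob onloads them.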
self.ref_manager.offload_param()
self.reference = Reference(
self.ref_model,
megatron_config=self.megatron_config,
beta=self.rl_config.beta,
mini_batch_size=self.rl_config.mini_batch_size,
epochs=self.rl_config.epochs,
shuffle_mini_batch=self.rl_config.shuffle_mini_batch,
generate_config=self.generate_config,
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
temperature=self.generate_config.sampling_config["temperature"],
use_remove_padding=self.rl_config.use_remove_padding,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size,
ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size
)
MsProbe.config_init(self.msprobe_config)
@mstx_timer_decorator
def compute_ref_log_prob(self):
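        """Compute reference-model log probabilities.

        Onload the reference parameters, run ReferenceWorkerBase.compute_ref_log_prob
        (with a temporary micro batch size if ref_forward_micro_batch_size is set),
        then offload the parameters again. Onload/offload timings are reported through
        self.td.update_metrics.
        """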
start_onload_time = time.time()
self.ref_manager.onload_param()
end_onload_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/ref_onload",
value=[round(end_onload_time, 4), round(start_onload_time, 4)],
cumulate=True
)
)
compute_log_prob_profiler = profiler_start(self.profiler_config, role="reference_compute_log_prob",
profiler_iteration=self.prof_iteration)
MsProbe.debugger_start(model=self.ref_model, tag="reference_compute_log_prob")
if self.ref_forward_micro_batch_size is not None:
with temporary_micro_batch_size(
worker=self.reference,
args=self.get_args(),
new_mbs=self.ref_forward_micro_batch_size
):
ReferenceWorkerBase.compute_ref_log_prob(self)
else:
ReferenceWorkerBase.compute_ref_log_prob(self)
profiler_step(compute_log_prob_profiler)
MsProbe.debugger_stop("reference_compute_log_prob")
start_offload_time = time.time()
self.ref_manager.offload_param()
torch.cuda.empty_cache()
end_offload_time = time.time()
ray.get(
self.td.update_metrics.remote(
"timing/ref_offload",
value=[round(end_offload_time, 4), round(start_offload_time, 4)],
cumulate=True
)
)
def compute_log_prob(self):
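        """Compute actor log probabilities, temporarily overriding the micro batch size
        when actor_forward_micro_batch_size is configured."""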
if self.actor_forward_micro_batch_size is not None:
with temporary_micro_batch_size(
worker=self.actor_hybrid.train_actor,
args=self.get_args(),
new_mbs=self.actor_forward_micro_batch_size
):
ActorHybridWorkerBase.compute_log_prob(self)
else:
ActorHybridWorkerBase.compute_log_prob(self)
def compute_image_embeds(self):
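        """Compute image embeddings, temporarily overriding the micro batch size
        when vit_forward_micro_batch_size is configured."""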
if self.vit_forward_micro_batch_size is not None:
with temporary_micro_batch_size(
worker=self.actor_hybrid.train_actor,
args=self.get_args(),
new_mbs=self.vit_forward_micro_batch_size
):
ActorHybridWorkerBase.compute_image_embeds(self)
else:
ActorHybridWorkerBase.compute_image_embeds(self)
def load_checkpoint_with_path(self, model, path, ckpt_only=False):
"""Load model checkpoint from a specified path with flexible control.
Args:
model: The model to load checkpoint into.
path: Path to the checkpoint file/directory. If None, use the path in megatron args.
ckpt_only: If True, only loads model weights (skips optimizer/RNG states).
"""
# Backup original arguments if needed
original_args = {
'no_load_optim': getattr(self.get_args(), "no_load_optim", None),
'no_load_rng': getattr(self.get_args(), "no_load_rng", None),
'load': getattr(self.get_args(), "load", None),
'iteration': getattr(self.get_args(), "iteration", None),
'finetune': getattr(self.get_args(), "finetune", None),
'consumed_train_samples': getattr(self.get_args(), "consumed_train_samples", None),
'consumed_valid_samples': getattr(self.get_args(), "consumed_valid_samples", None),
} if ckpt_only or path else {}
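        # Weights-only load: skip optimizer/RNG state and reset sample counters via finetune mode.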
if ckpt_only:
self._set_args({
"no_load_optim": True,
"no_load_rng": True,
"finetune": True,
'consumed_train_samples': 0,
'consumed_valid_samples': 0
})
if path is not None:
self._set_args({"load": path})
self.load_checkpoint(model, None, None)
if original_args:
self._set_args(original_args)
def _set_args(self, arg_dict):
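        """Set only those attributes that already exist on the Megatron args namespace."""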
for key, value in arg_dict.items():
if hasattr(self.get_args(), key):
setattr(self.get_args(), key, value)
@contextmanager
def temporary_micro_batch_size(worker, args, new_mbs):
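    """Temporarily override the micro batch size on both the worker and the Megatron args,
    restoring the original value on exit even if the wrapped call raises.
    """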
original_mbs = args.micro_batch_size
try:
worker.micro_batch_size = new_mbs
args.micro_batch_size = new_mbs
yield
finally:
worker.micro_batch_size = original_mbs
args.micro_batch_size = original_mbs