From ff8f71011bd5043a1b3fdae8c0d4587dc592c979 Mon Sep 17 00:00:00 2001
From: tangmengcheng <745274877@qq.com>
Date: Tue, 10 Jun 2025 20:29:54 +0800
Subject: [PATCH] profiler modify

---
 cli/train_grpo.py                           |  2 ++
 mindspeed_rl/workers/actor_hybrid_worker.py | 10 +++++-----
 mindspeed_rl/workers/integrated_worker.py   |  5 +----
 mindspeed_rl/workers/reference_woker.py     | 11 ++++++++++-
 4 files changed, 18 insertions(+), 10 deletions(-)

diff --git a/cli/train_grpo.py b/cli/train_grpo.py
index 5f822dfc..536dc3d4 100644
--- a/cli/train_grpo.py
+++ b/cli/train_grpo.py
@@ -89,6 +89,7 @@ def train(config):
         megatron_config=actor_config,
         rl_config=rl_config,
         generate_config=generate_config,
+        profiler_config=profiler_config["integrated"],
         model_provider=gpt_model_provider,
         tokenizer=tokenizer,
         initialize_func=initialize_megatron,
@@ -102,6 +103,7 @@ def train(config):
         megatron_config=ref_config,
         rl_config=rl_config,
         generate_config=generate_config,
+        profiler_config=profiler_config["integrated"],
         model_provider=gpt_model_provider,
         tokenizer=tokenizer,
         initialize_func=initialize_megatron,
diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py
index 0a81054a..8df75436 100644
--- a/mindspeed_rl/workers/actor_hybrid_worker.py
+++ b/mindspeed_rl/workers/actor_hybrid_worker.py
@@ -72,7 +72,7 @@ class ActorHybridWorkerBase(BaseWorker):
         self.actor_hybrid = None
         self.actor_offloader = None
         self.actor_profiler = None
-        self.prof_iteration = 1
+        self.actor_prof_iteration = 1
 
     def initialize(self):
         self.setup_distributed_rank()
@@ -160,7 +160,7 @@ class ActorHybridWorkerBase(BaseWorker):
         sorted_indexes = self.get_dp_range_indexes(experience_count,
                                                    use_vllm=False) if self.rl_config.guarantee_order else None
         actor_update_profiler = profiler_start(self.profiler_config, role="actor_update",
-                                               profiler_iteration=self.prof_iteration)
+                                               profiler_iteration=self.actor_prof_iteration)
         MsProbe.debugger_start(self.model[0], tag='actor_update')
 
         start_time_defined = False
@@ -195,7 +195,7 @@ class ActorHybridWorkerBase(BaseWorker):
         profiler_step(actor_update_profiler)
         MsProbe.debugger_stop(tag='actor_update')
         MsProbe.step()
-        self.prof_iteration += 1
+        self.actor_prof_iteration += 1
         start_sharding_exit_train = time.time()
         self.sharding_manager.exit_train_mode()
         sharding_train_interval += (time.time() - start_sharding_exit_train)
@@ -230,7 +230,7 @@ class ActorHybridWorkerBase(BaseWorker):
                                                    use_vllm=True) if self.rl_config.guarantee_order else None
 
         actor_generate_profiler = profiler_start(self.profiler_config, role="actor_generate",
-                                                 profiler_iteration=self.prof_iteration)
+                                                 profiler_iteration=self.actor_prof_iteration)
         MsProbe.debugger_start(self.inference_model.model, tag='actor_generate_sequences')
 
         start_time_defined = False
@@ -311,7 +311,7 @@ class ActorHybridWorkerBase(BaseWorker):
                                                    use_vllm=False) if self.rl_config.guarantee_order else None
 
         actor_compute_log_prob_profiler = profiler_start(self.profiler_config, role="actor_compute_log_prob",
-                                                         profiler_iteration=self.prof_iteration)
+                                                         profiler_iteration=self.actor_prof_iteration)
         MsProbe.debugger_start(self.model[0], tag='actor_compute_log_prob')
         start_time_defined = False
         while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py
index 5a4de8f8..debafffb 100644
--- a/mindspeed_rl/workers/integrated_worker.py
+++ b/mindspeed_rl/workers/integrated_worker.py
@@ -14,7 +14,7 @@ from mindspeed_rl.config_cls.generate_config import GenerateConfig
 from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig, MsprobeConfig
 from mindspeed_rl.utils.tokenizer import BaseTokenizer
 from mindspeed_rl.workers.resharding.megatron_sharding_manager import MegatronOffLoader
-from mindspeed_rl.utils.utils import mstx_timer_decorator, profiler_start, profiler_step
+from mindspeed_rl.utils.utils import mstx_timer_decorator
 from mindspeed_rl.utils.utils import MsProbe
 from mindspeed_rl.workers.actor_hybrid_worker import ActorHybridWorkerBase
 
@@ -117,8 +117,6 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB
                 cumulate=True
             )
         )
-        compute_log_prob_profiler = profiler_start(self.profiler_config, role="reference_compute_log_prob",
-                                                   profiler_iteration=self.prof_iteration)
         MsProbe.debugger_start(model=self.ref_model, tag="reference_compute_log_prob")
         if self.ref_forward_micro_batch_size is not None:
             with temporary_micro_batch_size(
@@ -129,7 +127,6 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB
                 ReferenceWorkerBase.compute_ref_log_prob(self)
         else:
             ReferenceWorkerBase.compute_ref_log_prob(self)
-        profiler_step(compute_log_prob_profiler)
         MsProbe.debugger_stop("reference_compute_log_prob")
         start_offload_time = time.time()
         self.ref_manager.offload_param()
diff --git a/mindspeed_rl/workers/reference_woker.py b/mindspeed_rl/workers/reference_woker.py
index 942582bb..946657c8 100644
--- a/mindspeed_rl/workers/reference_woker.py
+++ b/mindspeed_rl/workers/reference_woker.py
@@ -10,6 +10,8 @@ import torch
 from mindspeed_rl.config_cls.megatron_config import MegatronConfig
 from mindspeed_rl.config_cls.rl_config import RLConfig
 from mindspeed_rl.config_cls.generate_config import GenerateConfig
+from mindspeed_rl.config_cls.mindstudio_config import ProfilerConfig
+from mindspeed_rl.utils.utils import mstx_timer_decorator, profiler_start, profiler_step
 from mindspeed_rl.models.reference import Reference
 from mindspeed_rl.utils.pad_process import truncate_rows
 from mindspeed_rl.utils.tokenizer import BaseTokenizer
@@ -31,6 +33,7 @@ class ReferenceWorkerBase(BaseWorker):
         initialize_func: Callable Function to initialize the model and environment.
         tokenizer: BaseTokenizer = None Object to retrieve the tokenizer.
         get_megatron_module: Callable = megatron_module from get_megatron_module.
+        profiler_config: ProfilerConfig, Configuration for profiling.
         **kwargs: Additional parameters for base class argument passing.
     """
 
@@ -43,6 +46,7 @@ class ReferenceWorkerBase(BaseWorker):
             initialize_func: Callable,
             tokenizer: BaseTokenizer = None,
             get_megatron_module: Callable = None,
+            profiler_config: ProfilerConfig = None,
             **kwargs
     ):
         super().__init__(
@@ -53,9 +57,11 @@ class ReferenceWorkerBase(BaseWorker):
             initialize_func=initialize_func,
             tokenizer=tokenizer,
             get_megatron_module=get_megatron_module,
+            profiler_config=profiler_config,
             **kwargs
         )
         self.reference = None
+        self.ref_prof_iteration = 1
 
     def initialize(self):
         self.setup_distributed_rank()
@@ -92,6 +98,8 @@ class ReferenceWorkerBase(BaseWorker):
         sorted_indexes = self.get_dp_range_indexes(experience_count,
                                                    use_vllm=False) if self.rl_config.guarantee_order else None
 
+        compute_log_prob_profiler = profiler_start(self.profiler_config, role="reference_compute_log_prob",
+                                                   profiler_iteration=self.ref_prof_iteration)
         start_time_defined = False
         while self.all_consumed(experience_consumer_stage, sorted_indexes) > 0:
             batch_data, index = self.dispatch_transfer_dock_data(experience_consumer_stage,
@@ -130,7 +138,8 @@ class ReferenceWorkerBase(BaseWorker):
                     cumulate=True
                 )
             )
-
+        profiler_step(compute_log_prob_profiler)
+        self.ref_prof_iteration += 1
         parallel_state = get_parallel_state()
         use_vllm = False
         if is_pipeline_last_stage(parallel_state, use_vllm) and get_tensor_model_parallel_rank(parallel_state, use_vllm) == 0:
--
Gitee