diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py index 857ed67b8410af0eb32a3ad9682838ce3abf3055..86197079e5e48e001f1687a328b15592f5628d19 100644 --- a/mindspeed_rl/utils/utils.py +++ b/mindspeed_rl/utils/utils.py @@ -306,7 +306,7 @@ class MsProbe: hooked_model = [] @classmethod - def config_init(cls, msprobe_config): + def config_init(cls, msprobe_config, init_step=0): if not msprobe_config.msprobe: return cls.config = msprobe_config @@ -322,6 +322,7 @@ class MsProbe: if cls.need_debugger(): step = [f"{cls.config.step_start}-{cls.config.step_end}"] cls.debugger = PrecisionDebugger(task="statistics", level="L0", step=step, dump_path=cls.config.dump_path) + cls.debugger.set_init_step(init_step) cls.enabled = True print("msprobe enabled") diff --git a/mindspeed_rl/workers/actor_hybrid_worker.py b/mindspeed_rl/workers/actor_hybrid_worker.py index e4f5edb60f111478929a09a4add5cc98106dbdf1..3083de549d3c5d1561f740d0ea0d02d4c265b59a 100644 --- a/mindspeed_rl/workers/actor_hybrid_worker.py +++ b/mindspeed_rl/workers/actor_hybrid_worker.py @@ -121,7 +121,7 @@ class ActorHybridWorkerBase(BaseWorker): ) self.empty_cache() self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role) - MsProbe.config_init(self.msprobe_config) + MsProbe.config_init(self.msprobe_config, self.iteration + 1) def init_transfer_dock(self, td, mm_td): self.td = td diff --git a/mindspeed_rl/workers/integrated_worker.py b/mindspeed_rl/workers/integrated_worker.py index 175ee11ebb4753112b3b4096d8b4a72f6fa73ece..263a445772f7eddd2002b62045c78b56c280f230 100644 --- a/mindspeed_rl/workers/integrated_worker.py +++ b/mindspeed_rl/workers/integrated_worker.py @@ -111,7 +111,6 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB context_parallel_size=self.megatron_config.context_parallel_size, temperature=self.generate_config.sampling_config["temperature"] ) - MsProbe.config_init(self.msprobe_config) @mstx_timer_decorator def compute_ref_log_prob(self): diff --git a/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml b/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml index 196eb9af9c609671c8be435b860efe27b5f95436..2776a7021e887d7ba80bc1632e724463df0401ab 100644 --- a/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml +++ b/tests/st/configs/test_grpo_trainer_qwen25_7b_integrated.yaml @@ -146,5 +146,5 @@ msprobe_config: configurations_dump: true actor_train_dump: true reference_dump: true - step_start: 0 - step_end: 0 \ No newline at end of file + step_start: 1 + step_end: 1 \ No newline at end of file diff --git a/tests/st/mindstudio/check_and_clean_mindstudio_output.py b/tests/st/mindstudio/check_and_clean_mindstudio_output.py index 676dead62b8cd2a63cb8ac76724f74d1d1ff61af..8be04bca92dcb9cea25e6f54b95dbed1770d8627 100644 --- a/tests/st/mindstudio/check_and_clean_mindstudio_output.py +++ b/tests/st/mindstudio/check_and_clean_mindstudio_output.py @@ -61,9 +61,9 @@ def check_msprobe_output(msprobe_dir: str) -> bool: logger.error(f'Configurations directory not found: {os.path.join(msprobe_dir, "configurations.json")}') return False - if not os.path.isfile(os.path.join(msprobe_dir, "data", "responses", "step0", "rank0", "responses.json")): + if not os.path.isfile(os.path.join(msprobe_dir, "data", "responses", "step1", "rank0", "responses.json")): logger.error(f'Msprobe key data response not found: ' - f'{os.path.join(msprobe_dir, "data", "responses", "step0", "rank0", "responses.json")}') + f'{os.path.join(msprobe_dir, "data", "responses", "step1", "rank0", "responses.json")}') return False if not os.path.isdir(os.path.join(msprobe_dir, "actor_update")):