From 59c165417dfb7f4a822e122c76b5c5f8be13ff0b Mon Sep 17 00:00:00 2001 From: htwang Date: Tue, 27 May 2025 22:09:05 +0800 Subject: [PATCH] add rule reward for multimodal --- cli/train_grpo.py | 9 ++-- mindspeed_rl/config_cls/rl_config.py | 4 ++ mindspeed_rl/models/rule_verifier.py | 39 +++++++++++++--- mindspeed_rl/utils/utils.py | 34 +++++++------- mindspeed_rl/workers/rule_reward.py | 67 +++++++++++++++++++--------- requirements.txt | 4 ++ 6 files changed, 111 insertions(+), 46 deletions(-) diff --git a/cli/train_grpo.py b/cli/train_grpo.py index b499ee27..014bd233 100644 --- a/cli/train_grpo.py +++ b/cli/train_grpo.py @@ -46,10 +46,10 @@ def train(config): MsProbe.config_init(msprobe_config) MsProbe.save_configs({ - 'actor': eval(str(actor_config.dict())), - 'ref': eval(str(ref_config.dict())), - 'reward': eval(str(reward_config.dict())), - 'rl': eval(str(rl_config.dict())), + 'actor': eval(str(actor_config.dict())), + 'ref': eval(str(ref_config.dict())), + 'reward': eval(str(reward_config.dict())), + 'rl': eval(str(rl_config.dict())), 'generate': eval(str(generate_config.dict())) }) @@ -548,6 +548,7 @@ def main(config): rl_config = RLConfig(config.get("rl_config")) with open(os.path.join(cur_file_dir, rl_config.runtime_env_path)) as file: runtime_env = yaml.safe_load(file) + runtime_env["env_vars"]["IS_MULTIMODAL"] = str(rl_config.is_multimodal) logger.info(f"ray init with runtime_env: {runtime_env}") ray.init(runtime_env=runtime_env) diff --git a/mindspeed_rl/config_cls/rl_config.py b/mindspeed_rl/config_cls/rl_config.py index 44e65d3d..e2d14038 100644 --- a/mindspeed_rl/config_cls/rl_config.py +++ b/mindspeed_rl/config_cls/rl_config.py @@ -50,6 +50,8 @@ class RLConfig(BaseConfig): blocking: Whether to enable blocking mode (default: False) num_cpus_for_local_task: Number of CPUs for local ray task (default: 1) num_cpus_for_placement_group: Number of CPUs for ray worker placement group + + is_multimodal: Whether base model is a multimodal model or not (default: False) # Default values can still be defined if no config is provided ''' @@ -104,6 +106,8 @@ class RLConfig(BaseConfig): self.adv_dispatch_size = None self.actor_update_dispatch_size = None + self.is_multimodal = False + self.update(config_dict) self.n_samples_per_prompt = config_dict.get('n_samples_per_prompt', 1) diff --git a/mindspeed_rl/models/rule_verifier.py b/mindspeed_rl/models/rule_verifier.py index b88d833b..b49a7071 100644 --- a/mindspeed_rl/models/rule_verifier.py +++ b/mindspeed_rl/models/rule_verifier.py @@ -7,6 +7,7 @@ import threading import logging import torch +from mathruler.grader import extract_boxed_content, grade_answer from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.utils.math_eval_toolkit.grader import math_equal @@ -19,7 +20,7 @@ logger = Loggers("Rule verify") class GlobalProcessPool: _instance = None _lock = threading.Lock() - + def __init__(self, max_workers=16, reset_threshold=100000): self.max_workers = max_workers self.reset_threshold = reset_threshold @@ -27,17 +28,17 @@ class GlobalProcessPool: self.executor = None self.logger = logging.getLogger(__name__) self._initialize_executor() - + def _initialize_executor(self): """Initialize a new ProcessPoolExecutor and reset task counter.""" if self.executor is not None: self.executor.shutdown(wait=False) self.executor = None - gc.collect() + gc.collect() self.executor = ProcessPoolExecutor(max_workers=self.max_workers) self.task_counter = 0 self.logger.warning(f"Initialized ProcessPoolExecutor with {self.max_workers} workers") - + 
     @classmethod
     def get_instance(cls, max_workers=16, reset_threshold=100000) -> 'GlobalProcessPool':
         """Get or create the singleton instance of GlobalProcessPool."""
         if cls._instance is None:
             with cls._lock:
                 if cls._instance is None:
                     cls._instance = cls(max_workers=max_workers, reset_threshold=reset_threshold)
         return cls._instance
-    
+
     def submit(self, fn, *args, **kwargs):
         """
         Submit a task to the executor with automatic recovery and periodic reset.
-        
+
         Args:
             fn: Function to execute
             *args: Positional arguments for the function
             **kwargs: Keyword arguments for the function
-        
+
         Returns:
             Future object representing the computation
         """
@@ -379,3 +380,27 @@ def reasoning_steps_reward(sequences, *args, **kwargs):
     scores = [min(1.0, count / 3) for count in matches]
 
     return scores
+
+
+def math_format_reward(predict_str: str) -> float:
+    """
+    Reward function that checks if the completion has a specific format for math questions.
+    """
+    pattern = re.compile(r"<think>.*</think>.*\\boxed\{.*\}.*", re.DOTALL)
+    format_match = re.fullmatch(pattern, predict_str)
+    return 1.0 if format_match else 0.0
+
+
+def math_acc_reward(predict_str: str, ground_truth: str) -> float:
+    """
+    Reward function that checks if the answer is right by `mathruler`.
+    """
+    answer = extract_boxed_content(predict_str)
+    return 1.0 if grade_answer(answer, ground_truth) else 0.0
+
+
+def math_compute_score(predict_str: str, ground_truth: str, acc_ratio=0.9, format_ratio=0.1) -> float:
+    """
+    Compute score for math questions by format and accuracy reward.
+    """
+    return acc_ratio * math_acc_reward(predict_str, ground_truth) + format_ratio * math_format_reward(predict_str)
\ No newline at end of file
diff --git a/mindspeed_rl/utils/utils.py b/mindspeed_rl/utils/utils.py
index 8cb2f62c..44d5a631 100644
--- a/mindspeed_rl/utils/utils.py
+++ b/mindspeed_rl/utils/utils.py
@@ -97,7 +97,7 @@ def append_to_dict(data: Dict, new_data: Dict):
 def num_floating_point_operations(args, batch_size):
     """
     Calculate the number of floating-point operations for a given model configuration and batch size.
-    
+
     Args:
         args (object): An object containing various model configuration parameters, including:
             - kv_channels: The number of key-value channels in attention layers.
@@ -160,7 +160,7 @@ def get_batch_metrices_mean(metrics_list: List[Dict]) -> Dict[str, Tensor]:
         metrics_list: A list of dictionaries, where each dictionary contains metrics as key-value pairs.
 
     Returns:
-        metrics_mean: A dictionary where each key is a metric name and 
+        metrics_mean: A dictionary where each key is a metric name and
             each value is the mean of that metric across all batches.
     """
     batch_metrics = {}
@@ -179,7 +179,7 @@ def metrics_post_processing(metrics) -> Dict[str, Tensor]:
         metrics_list: A list of dictionaries, where each dictionary contains metrics as key-value pairs.
 
     Returns:
-        metrics_mean: A dictionary where each key is a metric name and 
+        metrics_mean: A dictionary where each key is a metric name and
             each value is the mean of that metric across all batches.
""" new_metrics = {} @@ -214,7 +214,7 @@ def metrics_sort(metrics, time_all) -> Dict[str, Tensor]: reference_start_time = metrics.pop('start_time/reference_model', None) reference_end_time = metrics.pop('end_time/reference', None) - + if old_log_p_end_time is None: old_log_p_end_time = reference_end_time custom_order.remove('timing/old_log_p') @@ -224,14 +224,14 @@ def metrics_sort(metrics, time_all) -> Dict[str, Tensor]: if "timing/rule_reward" in metrics.keys(): reward_start_time = metrics.pop('start_time/rule_reward', None) reward_end_time = metrics.pop('end_time/rule_reward', None) - non_overlap_rule_reward_time = max(reward_end_time - max(old_log_p_end_time, reward_start_time), 0) + non_overlap_rule_reward_time = max(reward_end_time - max(old_log_p_end_time, reward_start_time), 0) metrics["timing/non_overlap_rule_reward"] = non_overlap_rule_reward_time if "timing/reward_model" in metrics.keys(): reward_start_time = metrics.pop('start_time/reward_model', None) reward_end_time = metrics.pop('end_time/reward_model', None) - non_overlap_reward_model_time = max(reward_end_time - max(old_log_p_end_time, reward_start_time), 0) + non_overlap_reward_model_time = max(reward_end_time - max(old_log_p_end_time, reward_start_time), 0) metrics["timing/non_overlap_reward_model"] = non_overlap_reward_model_time - + metrics["timing/non_overlap_reference_model"] = non_overlap_reference_model_time metrics["timing/non_overlap_adv"] = non_overlap_adv_time @@ -250,7 +250,7 @@ def metrics_sort(metrics, time_all) -> Dict[str, Tensor]: def compute_tps(compute_kwargs, metrics_result, gbs, n_samples, time_all): - + actor_resource = compute_kwargs.get('actor_resource', {}) reference_resource = compute_kwargs.get('reference_resource', {}) reward_resource = compute_kwargs.get('reward_resource', None) @@ -325,14 +325,14 @@ class MsProbe: if not msprobe_config.msprobe: return cls.config = msprobe_config - + try: from msprobe.core import SingleSave from msprobe.pytorch import PrecisionDebugger except Exception as e: print("import msprobe error, msprobe not enabled") return - + cls.saver = SingleSave(cls.config.dump_path) if cls.need_debugger(): step = [f"{cls.config.step_start}-{cls.config.step_end}"] @@ -340,7 +340,7 @@ class MsProbe: cls.enabled = True print("msprobe enabled") - + @classmethod def save_configs(cls, data): if not cls.enabled: @@ -348,7 +348,7 @@ class MsProbe: if not cls.config.configurations_dump: return cls.saver.save_config(data) - + @classmethod def save_data(cls, data): if not cls.enabled: @@ -356,7 +356,7 @@ class MsProbe: if not cls.config.key_data_dump: return cls.saver.save(data) - + @classmethod def need_debugger(cls): if cls.config.reference_dump or cls.config.actor_train_dump or cls.config.actor_infer_dump: @@ -374,7 +374,7 @@ class MsProbe: if tag == "actor_generate_sequences" and cls.config.actor_infer_dump: return True return False - + @classmethod def debugger_start(cls, model=None, tag=None): if not cls.enabled: @@ -524,7 +524,7 @@ def profiler_start(profiler_config, role="profiler_data", profiler_iteration=Non if not profiler_config: return None if profiler_iteration is not None and ( - profiler_iteration < profiler_config.profile_step_start or + profiler_iteration < profiler_config.profile_step_start or profiler_iteration >= profiler_config.profile_step_end): return None if profiler_config.stage == "all" and role != profiler_config.role: @@ -541,3 +541,7 @@ def profiler_start(profiler_config, role="profiler_data", profiler_iteration=Non def profiler_step(profiler): if profiler: 
profiler.step() + + +def is_multimodal(): + return eval(os.getenv("IS_MULTIMODAL", "False")) diff --git a/mindspeed_rl/workers/rule_reward.py b/mindspeed_rl/workers/rule_reward.py index fac16689..a68dba58 100644 --- a/mindspeed_rl/workers/rule_reward.py +++ b/mindspeed_rl/workers/rule_reward.py @@ -1,11 +1,12 @@ # Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved. import ray +import torch from transformers import AutoTokenizer -from mindspeed_rl.models.rule_verifier import compute_verifier_score +from mindspeed_rl.models.rule_verifier import compute_verifier_score, math_compute_score from mindspeed_rl.utils.loggers import Loggers from mindspeed_rl.trainer.utils.transfer_dock import pad_experience -from mindspeed_rl.utils.utils import get_current_dp_range_indexes +from mindspeed_rl.utils.utils import get_current_dp_range_indexes, is_multimodal logger = Loggers("rule_reward") @@ -20,9 +21,10 @@ class RuleReward(object): self.tokenizer = tokenizer self.hf_tokenizer = AutoTokenizer.from_pretrained(megatron_config.tokenizer_name_or_path, trust_remote_code=True) - - def init_transfer_dock(self, td): + + def init_transfer_dock(self, td, mm_td=None): self.td = td + self.mm_td = mm_td def compute_rm_score(self): experience_consumer_stage = 'rule_reward' @@ -45,22 +47,47 @@ class RuleReward(object): if batch_data and index: batch_data = pad_experience(batch_data, pad_token_id) # multiple, tp_size - if "categories" in batch_data.keys(): - use_verifier_mask = batch_data["categories"][:, 0].squeeze().bool() - selected_index = [index[i] for i in range(len(index)) if use_verifier_mask[i]] - index = selected_index - if not index: - continue - if "categories" in batch_data.keys(): - batch_data = {key: value[use_verifier_mask] if key != 'prompts' else value[ - use_verifier_mask[::self.n_samples_per_prompt]] for key, value in batch_data.items()} - ignore_token = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod + if not is_multimodal(): + if "categories" in batch_data.keys(): + use_verifier_mask = batch_data["categories"][:, 0].squeeze().bool() + selected_index = [index[i] for i in range(len(index)) if use_verifier_mask[i]] + index = selected_index + if not index: + continue + if "categories" in batch_data.keys(): + batch_data = {key: value[use_verifier_mask] if key != 'prompts' else value[ + use_verifier_mask[::self.n_samples_per_prompt]] for key, value in batch_data.items()} + ignore_token = self.tokenizer.pad if self.tokenizer.pad else self.tokenizer.eod + + token_level_rewards, metrics = compute_verifier_score(batch_data, self.megatron_config, self.rl_config, + self.hf_tokenizer, ignore_token) + + for key, value in metrics.items(): + ray.get(self.td.update_metrics.remote(key, value=value, cumulate=True)) + + output = {"rm_scores": token_level_rewards, "token_level_rewards": token_level_rewards} + else: + mm_columns = ray.get(self.mm_td.get_columns.remote(experience_consumer_stage)) + batch_mm_data = ray.get(self.mm_td.get_experience.remote(mm_columns, index)) + batch_data.update(batch_mm_data) - token_level_rewards, metrics = compute_verifier_score(batch_data, self.megatron_config, self.rl_config, - self.hf_tokenizer, ignore_token) - - for key, value in metrics.items(): - ray.get(self.td.update_metrics.remote(key, value=value, cumulate=True)) + reward_tensor = torch.zeros((batch_data['responses'].size(0), 1), dtype=torch.float32) + original_shape = reward_tensor.shape + responses = batch_data['responses'] + response_strs = self.hf_tokenizer.batch_decode(responses, 
skip_special_tokens=True) + labels = [] + for _ in range(self.n_samples_per_prompt): + for label in batch_data['labels']: + labels.append(label) - output = {"rm_scores": token_level_rewards, "token_level_rewards": token_level_rewards} + for i, (response_str, label) in enumerate(zip(response_strs, labels)): + token_level_rewards = math_compute_score(response_str, label) + reward_tensor[i, 0] = token_level_rewards + rm_scores = reward_tensor + reward_tensor_reshaped = reward_tensor.reshape(-1, self.n_samples_per_prompt) + reward_mean = reward_tensor_reshaped.mean(dim=1, keepdim=True) + reward_std = reward_tensor_reshaped.std(dim=1, keepdim=True) + 1e-6 + reward_tensor_normalized = (reward_tensor_reshaped - reward_mean) / reward_std + reward_tensor = reward_tensor_normalized.reshape(original_shape) + output = {"rm_scores": rm_scores, "token_level_rewards": reward_tensor} self.td.put_experience.remote(data_dict=output, indexes=index) diff --git a/requirements.txt b/requirements.txt index 49d9a910..d431fd76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,7 @@ tensordict==0.1.2 transformers==4.48.2 word2number wandb +pillow==11.2.1 +av==14.3.0 +qwen_vl_utils==0.0.11 +mathruler -- Gitee