From 3bb93698a1da6ff4524e087d46395be2c1f4e2dd Mon Sep 17 00:00:00 2001
From: peng-hengduo
Date: Tue, 19 Aug 2025 10:55:00 +0800
Subject: [PATCH 1/4] update qwen2vl_dpo_trainer to follow the pretrain_vlm method
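Rework model_provider to build the generic VLMModel from a configurable
module list instead of a hard-coded Qwen2VLModel, matching the
pretrain_vlm entry point. A minimal usage sketch (the trainer instance
and config layout here are illustrative assumptions, not a documented
API):

    # Build a text-plus-image model; leaving 'audio_encoder' out of the
    # list sets vlm_config.audio_encoder to None, so no audio tower is built.
    model = trainer.model_provider(pre_process=True, post_process=True,
                                   modules=['image_encoder', 'text_decoder'])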
---
 .../tasks/rl/dpo/qwen2vl_dpo_trainer.py | 104 ++++++++++++------
 1 file changed, 68 insertions(+), 36 deletions(-)

diff --git a/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py b/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
index 0868f1df..a27fda21 100644
--- a/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
+++ b/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
 from copy import deepcopy
+from typing import Dict, Any
 from functools import partial
 
 import torch
@@ -10,6 +11,7 @@ from megatron.training.checkpointing import load_checkpoint
 from megatron.training.global_vars import set_args
 from megatron.training.training import get_model
 from mindspeed_mm.models.qwen2vl_model import Qwen2VLModel
+from mindspeed_mm.models.vlm_model import VLMModel
 from mindspeed_mm.tasks.finetune.lora.utils import is_enable_lora
 from mindspeed_mm.tasks.rl.dpo.dpo_trainer import DPOTrainer
 from mindspeed_mm.tasks.rl.dpo.qwen2vl_dpo_model import Qwen2VLDPOModel
@@ -52,31 +54,57 @@ class Qwen2VLDPOTrainer(DPOTrainer):
         )
         self.disable_dropout()
 
-    def model_provider(self, pre_process=True, post_process=True):
+    def model_provider(self, pre_process=True, post_process=True, modules=None):
         """Builds the model."""
+        if modules is None:
+            modules = ['image_encoder', 'audio_encoder', 'text_decoder']
+
         args = get_args()
-        print_rank_0("building QWen2VL model ...")
+        print_rank_0("building VLMModel ...")
         vlm_config = deepcopy(args.mm.model)
 
         # distinguish model construct stage when pipeline parallel
         vlm_config.pre_process = pre_process
         vlm_config.post_process = post_process
 
-        if vlm_config.image_encoder:
-            vlm_config.image_encoder.vision_encoder = get_model_config(vlm_config.image_encoder.vision_encoder)
-            vlm_config.image_encoder.vision_projector = get_model_config(vlm_config.image_encoder.vision_projector)
-            vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
+        self._configure_modules(vlm_config, modules)
 
-            model = Qwen2VLModel(vlm_config)
+        model = VLMModel(vlm_config)
 
-            model.freeze(freeze_image_encoder=getattr(vlm_config.image_encoder.vision_encoder, 'freeze', True),
-                         freeze_image_projection=getattr(vlm_config.image_encoder.vision_projector, 'freeze', True))
-        else:
-            vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
-            model = Qwen2VLModel(vlm_config)
+        self._apply_freezing(model, vlm_config)
 
         return model
+
+    def _configure_modules(self, vlm_config, modules):
+        """Configure each module based on the modules list."""
+        module_configs = {
+            'image_encoder': self._configure_image_encoder,
+            'audio_encoder': self._configure_audio_encoder,
+            'text_decoder': self._configure_text_decoder
+        }
+
+        for module_name, config_func in module_configs.items():
+            if module_name in modules and hasattr(vlm_config, module_name):
+                config_func(vlm_config)
+            else:
+                setattr(vlm_config, module_name, None)
+
+    def _configure_image_encoder(self, vlm_config):
+        """Configure image encoder module."""
+        vlm_config.image_encoder.vision_encoder = get_model_config(vlm_config.image_encoder.vision_encoder)
+        vlm_config.image_encoder.vision_projector = get_model_config(vlm_config.image_encoder.vision_projector)
+
+    def _configure_audio_encoder(self, vlm_config):
+        """Configure audio encoder module."""
+        vlm_config.audio_encoder.audio_encoder = get_model_config(vlm_config.audio_encoder.audio_encoder)
+
+    def _configure_text_decoder(self, vlm_config):
+        """Configure text decoder module."""
+        vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
+
     def disable_dropout(self):
         """
         disable dropout
@@ -88,36 +116,40 @@ class Qwen2VLDPOTrainer(DPOTrainer):
         args_.retro_encoder_attention_dropout = 0.0
         set_args(args_)
 
-    @staticmethod
-    def get_batch(data_iterator):
+    def get_batch(self, data_iterator):
         """Generate a batch."""
         if data_iterator is not None:
             batch = next(data_iterator)
         else:
             raise ValueError("Data iterator is None. Unable to retrieve batch.")
-        input_ids = batch['input_ids'].to(torch.cuda.current_device())
-        labels = batch['labels'].to(torch.cuda.current_device())
-        attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
-        has_image = 'pixel_values' in batch and 'image_grid_thw' in batch
+        self.move_to_device(batch, get_args().params_dtype)
         has_video = 'pixel_values_videos' in batch and 'video_grid_thw' in batch
-        if has_image or has_video:
-            if has_image:
-                pixel_values = batch['pixel_values'].to(torch.cuda.current_device())
-                image_grid_thw = batch['image_grid_thw'].to(torch.cuda.current_device())
-            if has_video:
-                pixel_values = batch['pixel_values_videos'].to(torch.cuda.current_device())
-                image_grid_thw = batch['video_grid_thw'].to(torch.cuda.current_device())
-        else:  # text-only sample
-            pixel_values = None
-            image_grid_thw = None
-        batch = {
-            'input_ids': input_ids,
-            'labels': labels,
-            'attention_mask': attention_mask,
-            'pixel_values': pixel_values,
-            'image_grid_thw': image_grid_thw
-        }
-        return batch['input_ids'], batch['labels'], batch['attention_mask'], batch['pixel_values'], batch['image_grid_thw']
+
+        input_ids = batch.pop('input_ids').to(torch.cuda.current_device())
+        labels = batch.pop('labels').to(torch.cuda.current_device())
+        attention_mask =batch.pop('attention_mask').to(torch.cuda.current_device())
+
+        if has_video:
+            pixel_values = batch.pop('pixel_values_videos').to(torch.cuda.current_device())
+            image_grid_thw = batch.pop('video_grid_thw').to(torch.cuda.current_device())
+        else:
+            pixel_values = batch.pop('pixel_values', None)
+            image_grid_thw = batch.pop('image_grid_thw', None)
+            if pixel_values is not None:
+                pixel_values = pixel_values.to(torch.cuda.current_device())
+            if image_grid_thw is not None:
+                image_grid_thw = image_grid_thw.to(torch.cuda.current_device())
+        return input_ids, labels, attention_mask, pixel_values, image_grid_thw
+
+    def move_to_device(self, batch: Dict[str, Any], float_dtype: torch.dtype):
+        # Move every tensor (or list of tensors) in the batch to the current
+        # device, casting floating-point tensors to float_dtype.
+        for k, v in batch.items():
+            if isinstance(v, torch.Tensor):
+                dtype = float_dtype if torch.is_floating_point(v) else None
+                batch[k] = v.to(device=torch.cuda.current_device(), dtype=dtype)
+            elif isinstance(v, list) and all(isinstance(t, torch.Tensor) for t in v):
+                batch[k] = [t.to(device=torch.cuda.current_device(),
+                                 dtype=float_dtype if torch.is_floating_point(t) else None)
+                            for t in v]
 
     def forward_step(self, data_iterator, model):
         """DPO Forward training step.
-- 
Gitee

From ca9c7be19bc5c30d1b6d3807febb97f3d0b4a3ef Mon Sep 17 00:00:00 2001
From: peng-hengduo
Date: Tue, 19 Aug 2025 10:57:50 +0800
Subject: [PATCH 2/4] delete pretrain_qwen2vl and update SECURITYNOTE.md
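pretrain_vlm.py already covers the Qwen2VL pretraining path, so the
model-specific entry point is deleted and dropped from the public
interface list in SECURITYNOTE.md. The replacement keeps the same entry
pattern as the deleted script; the sketch below assumes pretrain_vlm.py
defines equivalent provider functions:

    # pretrain_vlm.py entry point, mirroring the deleted __main__ block
    train_valid_test_datasets_provider.is_distributed = True
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        extra_args_provider=mm_extra_args_provider,
        args_defaults={"dataloader_type": "external"},
    )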
---
 docs/SECURITYNOTE.md |   1 -
 pretrain_qwen2vl.py  | 148 ------------------------------------------
 2 files changed, 149 deletions(-)
 delete mode 100644 pretrain_qwen2vl.py

diff --git a/docs/SECURITYNOTE.md b/docs/SECURITYNOTE.md
index b50e8988..65bb0b7a 100644
--- a/docs/SECURITYNOTE.md
+++ b/docs/SECURITYNOTE.md
@@ -97,7 +97,6 @@ MindSpeed MM 暂时未发布wheel包，无正式对外公开接口，所有功
 - [pretrain_whisper](https://gitee.com/ascend/MindSpeed-MM/blob/master/pretrain_whisper.py)
 - [pretrain_ae](https://gitee.com/ascend/MindSpeed-MM/blob/master/pretrain_ae.py)
 - [pretrain_internvl](https://gitee.com/ascend/MindSpeed-MM/blob/master/pretrain_internvl.py)
-- [pretrain_qwen2vl](https://gitee.com/ascend/MindSpeed-MM/blob/master/pretrain_qwen2vl.py)
 - [pretrain_vlm](https://gitee.com/ascend/MindSpeed-MM/blob/master/pretrain_vlm.py)
 
 ## 通信安全加固
diff --git a/pretrain_qwen2vl.py b/pretrain_qwen2vl.py
deleted file mode 100644
index b7b5ae33..00000000
--- a/pretrain_qwen2vl.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-"""Pretrain QWEN2VL."""
-from copy import deepcopy
-from functools import partial
-
-import mindspeed.megatron_adaptor
-import torch
-
-from datasets import Dataset
-from megatron.core import mpu
-from megatron.core.enums import ModelType
-from megatron.training import get_args, print_rank_0
-from megatron.training.utils import average_losses_across_data_parallel_group
-
-from mindspeed_mm.configs.config import mm_extra_args_provider
-from mindspeed_mm.data import build_mm_dataloader, build_mm_dataset
-from mindspeed_mm.data.data_utils.utils import build_iterations
-from mindspeed_mm.models.qwen2vl_model import Qwen2VLModel
-from mindspeed_mm.training import pretrain
-from mindspeed_mm.utils.transformer_model_config import get_model_config
-from mindspeed_mm.patchs import dummy_optimizer_patch
-
-
-def model_provider(pre_process=True, post_process=True):
-    """Builds the model."""
-    args = get_args()
-    print_rank_0("building QWen2VL model ...")
-    vlm_config = deepcopy(args.mm.model)
-
-    # distinguish model construct stage when pipeline parallel
-    vlm_config.pre_process = pre_process
-    vlm_config.post_process = post_process
-
-    if vlm_config.image_encoder:
-        vlm_config.image_encoder.vision_encoder = get_model_config(vlm_config.image_encoder.vision_encoder)
-        vlm_config.image_encoder.vision_projector = get_model_config(vlm_config.image_encoder.vision_projector)
-        vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
-
-        model = Qwen2VLModel(vlm_config)
-
-        model.freeze(freeze_image_encoder=getattr(vlm_config.image_encoder.vision_encoder, 'freeze', True), \
-                     freeze_image_projection=getattr(vlm_config.image_encoder.vision_projector, 'freeze', True))
-    else:
-        vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
-        model = Qwen2VLModel(vlm_config)
-
-    return model
-
-
-def get_batch(data_iterator):
-    """Generate a batch."""
-    if data_iterator is not None:
-        batch = next(data_iterator)
-    else:
-        raise ValueError("Data iterator is None. Unable to retrieve batch.")
-    input_ids = batch['input_ids'].to(torch.cuda.current_device())
-    labels = batch['labels'].to(torch.cuda.current_device())
-    attention_mask = batch['attention_mask'].to(torch.cuda.current_device())
-    has_image = 'pixel_values' in batch and 'image_grid_thw' in batch
-    has_video = 'pixel_values_videos' in batch and 'video_grid_thw' in batch
-    if has_image or has_video:
-        if has_image:
-            pixel_values = batch['pixel_values'].to(torch.cuda.current_device())
-            image_grid_thw = batch['image_grid_thw'].to(torch.cuda.current_device())
-        if has_video:
-            pixel_values = batch['pixel_values_videos'].to(torch.cuda.current_device())
-            image_grid_thw = batch['video_grid_thw'].to(torch.cuda.current_device())
-    else:  # text-only sample
-        pixel_values = None
-        image_grid_thw = None
-    batch = {
-        'input_ids': input_ids,
-        'labels': labels,
-        'attention_mask': attention_mask,
-        'pixel_values': pixel_values,
-        'image_grid_thw': image_grid_thw
-    }
-    return batch['input_ids'], batch['labels'], batch['attention_mask'], batch['pixel_values'], batch['image_grid_thw']
-
-
-def loss_func(output_tensor):
-    """Loss function."""
-    args = get_args()
-    loss = output_tensor['loss'].mean()
-    loss_dir = {}
-    if args.log_tps:
-        B, S, _ = output_tensor['logits'].shape
-        dp_size = torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
-        tokens_per_sample = torch.tensor(S, device=output_tensor['logits'].device) / dp_size
-        torch.distributed.all_reduce(tokens_per_sample, group=mpu.get_data_parallel_group())
-        loss_dir["tokens per sample"] = tokens_per_sample
-    averaged_loss = average_losses_across_data_parallel_group([loss])
-    loss_dir["loss"] = averaged_loss[0]
-    loss = loss.unsqueeze(0).clone()
-    return loss, loss_dir
-
-
-def forward_step(data_iterator, model):
-    """Forward step."""
-
-    input_ids, labels, attention_mask, pixel_values, image_grid_thw = get_batch(data_iterator)
-
-    output_tensor = model(input_ids=input_ids, pixel_values=pixel_values, image_grid_thw=image_grid_thw,
-                          attention_mask=attention_mask, labels=labels)
-    return output_tensor, loss_func
-
-
-def train_valid_test_datasets_provider(train_val_test_num_samples):
-    """Build train, valid, and test datasets."""
-    args = get_args()
-    data_config = args.mm.data
-    datasets = build_mm_dataset(data_config.dataset_param)
-    build_dataloader = partial(build_mm_dataloader,
-                               dataloader_param=data_config.dataloader_param,
-                               process_group=mpu.get_data_parallel_group(),
-                               dataset_param=data_config.dataset_param,
-                               consumed_samples=args.consumed_train_samples)
-
-    if isinstance(datasets, tuple) and len(datasets) == 2:
-        train_dataset, val_dataset = datasets
-        train_dataloader = build_dataloader(train_dataset)
-        valid_dataloader = build_dataloader(val_dataset)
-        train_dataloader, val_dataloader, test_dataloader = build_iterations(train_dataloader, valid_dataloader)
-    else:
-        train_dataset = datasets
-        val_rate = getattr(data_config.dataset_param.basic_parameters, 'val_rate', 0.0)
-        if isinstance(train_dataset, Dataset) and val_rate > 0:
-            dataset = train_dataset.train_test_split(test_size=val_rate, seed=args.seed)
-            train_dataset, val_dataset = dataset['train'], dataset['test']
-            train_dataloader = build_dataloader(train_dataset)
-            valid_dataloader = build_dataloader(val_dataset)
-            train_dataloader, val_dataloader, test_dataloader = build_iterations(train_dataloader, valid_dataloader)
-        else:
-            train_dataloader = build_dataloader(train_dataset)
-            train_dataloader, val_dataloader, test_dataloader = build_iterations(train_dataloader)
-    return train_dataloader, val_dataloader, test_dataloader
-
-
-if __name__ == "__main__":
-    train_valid_test_datasets_provider.is_distributed = True
-    pretrain(
-        train_valid_test_datasets_provider,
-        model_provider,
-        ModelType.encoder_or_decoder,
-        forward_step,
-        extra_args_provider=mm_extra_args_provider,
-        args_defaults={"dataloader_type": "external"},
-    )
-- 
Gitee

From 06251e8596cf28029ec8dd351df1f634ebd79df9 Mon Sep 17 00:00:00 2001
From: peng-hengduo
Date: Tue, 19 Aug 2025 11:25:48 +0800
Subject: [PATCH 3/4] fix missing space in qwen2vl_dpo_trainer get_batch
---
 mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py b/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
index a27fda21..27dff2f2 100644
--- a/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
+++ b/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
@@ -127,7 +127,7 @@ class Qwen2VLDPOTrainer(DPOTrainer):
 
         input_ids = batch.pop('input_ids').to(torch.cuda.current_device())
         labels = batch.pop('labels').to(torch.cuda.current_device())
-        attention_mask =batch.pop('attention_mask').to(torch.cuda.current_device())
+        attention_mask = batch.pop('attention_mask').to(torch.cuda.current_device())
 
         if has_video:
             pixel_values = batch.pop('pixel_values_videos').to(torch.cuda.current_device())
-- 
Gitee

From f4509d6c2cbb5b0e3fbd97a72f21f8fca72346a1 Mon Sep 17 00:00:00 2001
From: peng-hengduo
Date: Tue, 19 Aug 2025 14:19:15 +0800
Subject: [PATCH 4/4] add the _apply_freezing helper overlooked in patch 1
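Patch 1 calls self._apply_freezing but never defines it; this patch adds
the missing helper. A sketch of the resulting freeze call for an
image-plus-text config (keyword names are taken from this patch, while
the per-module 'freeze' defaults shown are assumptions read off the
getattr fallbacks):

    model.freeze(
        freeze_image_encoder=True,      # vision_encoder.freeze falls back to True
        freeze_image_projection=False,  # vision_projector.freeze falls back to False
        freeze_audio_encoder=False      # no audio_encoder configured
    )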
---
 mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py b/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
index 27dff2f2..c97ee359 100644
--- a/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
+++ b/mindspeed_mm/tasks/rl/dpo/qwen2vl_dpo_trainer.py
@@ -105,6 +105,21 @@ class Qwen2VLDPOTrainer(DPOTrainer):
         """Configure text decoder module."""
         vlm_config.text_decoder = get_model_config(vlm_config.text_decoder)
 
+    def _apply_freezing(self, model, vlm_config):
+        """Apply freezing settings to the model."""
+        has_image = hasattr(vlm_config, 'image_encoder') and vlm_config.image_encoder is not None
+        freeze_image_encoder = has_image and getattr(vlm_config.image_encoder.vision_encoder, 'freeze', True)
+        freeze_image_projection = has_image and getattr(vlm_config.image_encoder.vision_projector, 'freeze', False)
+
+        has_audio = hasattr(vlm_config, 'audio_encoder') and vlm_config.audio_encoder is not None
+        freeze_audio_encoder = has_audio and getattr(vlm_config.audio_encoder.audio_encoder, 'freeze', True)
+
+        model.freeze(
+            freeze_image_encoder=freeze_image_encoder,
+            freeze_image_projection=freeze_image_projection,
+            freeze_audio_encoder=freeze_audio_encoder
+        )
+
     def disable_dropout(self):
         """
         disable dropout
-- 
Gitee