From c13b0a057b4a5926234af446c2be2aa27246fd59 Mon Sep 17 00:00:00 2001
From: lmy
Date: Tue, 19 Aug 2025 21:27:05 +0800
Subject: [PATCH] [feature]DanceGRPO flux train example

---
 examples/dancegrpo/README.md                  | 241 ++++++++++++++++++
 examples/dancegrpo/data_dancegrpo.json        |   7 +
 examples/dancegrpo/model_dancegrpo.json       |  13 +
 .../dancegrpo/posttrain_flux_dancegrpo.sh     | 109 ++++++++
 .../preprocess_flux_rl_embeddings.sh          |  35 +++
 examples/dancegrpo/requirements-lint.txt      |  76 ++++++
 posttrain_flux_dancegrpo.py                   |  20 ++
 7 files changed, 501 insertions(+)
 create mode 100644 examples/dancegrpo/README.md
 create mode 100644 examples/dancegrpo/data_dancegrpo.json
 create mode 100644 examples/dancegrpo/model_dancegrpo.json
 create mode 100644 examples/dancegrpo/posttrain_flux_dancegrpo.sh
 create mode 100644 examples/dancegrpo/preprocess_flux_rl_embeddings.sh
 create mode 100644 examples/dancegrpo/requirements-lint.txt
 create mode 100644 posttrain_flux_dancegrpo.py

diff --git a/examples/dancegrpo/README.md b/examples/dancegrpo/README.md
new file mode 100644
index 00000000..d132ae24
--- /dev/null
+++ b/examples/dancegrpo/README.md
@@ -0,0 +1,241 @@
# FLUX DanceGRPO Usage Guide


## Contents

- [Introduction](#jump0)
- [Environment Setup](#jump1)
  - [Repository Setup](#jump1.1)
  - [Environment Installation](#jump1.2)
- [Weight Download](#jump2)
- [Dataset Preparation and Preprocessing](#jump3)
- [Training](#jump4)
  - [Preparation](#jump4.1)
  - [Launching Training](#jump4.2)

<a id="jump0"></a>

## Introduction

This guide helps you get started quickly by reproducing the [DanceGRPO](https://arxiv.org/abs/2505.07818) post-training method with the MindSpeed MM repository. First complete the preparation of the code repositories, environment, dataset, and weights, then launch training as described below.

### Reference Implementation

The DanceGRPO open-source repository and the corresponding commit id:

```
url=https://github.com/XueZeyue/DanceGRPO
commit_id=2149f36f22db601f9dbf70472fea11576f62a0f6
```

<a id="jump1"></a>

## Environment Setup

It is recommended to use the matching environment versions for model development.

Please refer to the [installation guide](../../docs/user-guide/installation.md).

> For DanceGRPO, Python 3.10 is recommended.

<a id="jump1.1"></a>

### 1. Repository Setup

```shell
git clone https://gitee.com/ascend/MindSpeed-MM.git
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
git checkout core_v0.12.1
cp -r megatron ../MindSpeed-MM/
cd ..

cd MindSpeed-MM
mkdir -p logs data ckpt
cd ..
```

<a id="jump1.2"></a>

### 2. Environment Installation

```bash
# python3.10
conda create -n test python=3.10
conda activate test

# for torch-npu dev version or x86 machine [Optional]
# pip config set global.extra-index-url "https://download.pytorch.org/whl/cpu/ https://mirrors.huaweicloud.com/ascend/repos/pypi"
# install torch and torch_npu
pip install torch-2.7.1+cpu-cp310-cp310-*.whl
pip install torch_npu-2.7.1*.whl

# install the MindSpeed acceleration library
git clone https://gitee.com/ascend/MindSpeed.git
cd MindSpeed
git checkout 6d63944cb2470a0bebc38dfb65299b91329b8d92
cp -r mindspeed ../MindSpeed-MM/
cd ..

# install the DanceGRPO dependencies
cd MindSpeed-MM
pip install -r ./examples/dancegrpo/requirements-lint.txt
cd ..

git clone https://github.com/tgxs002/HPSv2.git
cd HPSv2
git checkout 866735ecaae999fa714bd9edfa05aa2672669ee3
pip install -e .
cd ..
```

### 3. Decord Installation

**x86 installation**

```bash
pip install decord==0.6.0
```

**ARM installation**

For installation via `apt`, see the [reference link](https://github.com/dmlc/decord).

For installation via `yum`, see the [reference build script](https://github.com/dmlc/decord/blob/master/tools/build_manylinux2010.sh).

<a id="jump2"></a>

## Weight Download

Create the directories for storing the weights:

```bash
cd MindSpeed-MM
mkdir ckpt/flux
mkdir ckpt/hps_ckpt
cd ..
```

Download the [FLUX pretrained weights](https://huggingface.co/black-forest-labs/FLUX.1-dev) into the ckpt/flux directory under the MindSpeed MM project root.

Download the [HPS-v2.1 pretrained weights](https://huggingface.co/xswu/HPSv2/tree/main) into the ckpt/hps_ckpt directory under the MindSpeed MM project root.

<a id="jump3"></a>

## Dataset Preparation and Preprocessing

Download the [prompt dataset](https://github.com/XueZeyue/DanceGRPO/blob/main/prompts.txt) used by FLUX DanceGRPO: on the file page, click "Download raw file" and save the file to the data directory under the MindSpeed MM project root.

After downloading, the dataset must be preprocessed. Before launching preprocessing, adjust the [data preprocessing script](./preprocess_flux_rl_embeddings.sh) to match your setup. Taking the FLUX model as an example:

1. `LOAD_PATH` is the path to the VAE model weights (the FLUX pretrained weight directory downloaded above), default ckpt/flux;
2. `OUTPUT_DIR` is the directory where the preprocessed dataset is written, default data/rl_embeddings;
3. `PROMPT_DIR` is the path to the prompt file, default data/prompts.txt.

Once these are set, launch preprocessing:

```bash
cd MindSpeed-MM
bash examples/dancegrpo/preprocess_flux_rl_embeddings.sh
```

By default, the processed data is stored in data/rl_embeddings under the MindSpeed MM root directory.
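As a quick sanity check after preprocessing, you can confirm that the output directory and the metadata file were written. The snippet below is a minimal sketch: it assumes the defaults above (`OUTPUT_DIR=data/rl_embeddings`) and that `videos2caption.json` is a top-level JSON container with one entry per prompt; the exact field names inside it may differ.

```bash
cd MindSpeed-MM
# list a few of the generated embedding files
ls data/rl_embeddings | head
# count the metadata entries (assumes videos2caption.json is a top-level JSON list or dict)
python -c "import json; print(len(json.load(open('data/rl_embeddings/videos2caption.json'))))"
```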
<a id="jump4"></a>

## Training

<a id="jump4.1"></a>

### 1. Preparation

Before configuring the scripts, complete the prerequisites: **environment setup**, **weight download**, and **dataset preparation and preprocessing**. See the corresponding sections for details.

### 2. Third-Party Library Modifications

Locate the root directory of the Python environment in use; for a conda-managed environment it can be found with:
```bash
echo $(conda info --envs | grep test) | awk '{print $NF}'
```

1. In `lib/python3.10/site-packages/diffusers/models/embeddings.py`, change the following code in the `forward` method of the `FluxPosEmbed` class:
```python
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
```
to:
```python
is_mps = ids.device.type == "mps"
is_npu = ids.device.type == "npu"
freqs_dtype = torch.float32 if is_mps or is_npu else torch.float64
```

2. In `lib/python3.10/site-packages/diffusers/models/embeddings.py`, change the following code in the `get_1d_rotary_pos_embed` function:
```python
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float()  # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float()  # [S, D]
```
to:
```python
freqs_cos = freqs.cos().T.repeat_interleave(2, dim=0).T.contiguous().float()
freqs_sin = freqs.sin().T.repeat_interleave(2, dim=0).T.contiguous().float()
```

3. In `lib/python3.10/site-packages/diffusers/models/attention_processor.py`, change the following code in the `__init__` method of the `Attention` class:
```python
elif qk_norm == "rms_norm":
    self.norm_q = RMSNorm(dim_head, eps=eps)
    self.norm_k = RMSNorm(dim_head, eps=eps)
```
to:
```python
elif qk_norm == "rms_norm":
    self.norm_q = NpuFusedRMSNorm(dim_head, eps=eps)
    self.norm_k = NpuFusedRMSNorm(dim_head, eps=eps)
```

and add the following class to the same file (add `import torch_npu` at the top of the file if it is not already imported):
```python
class NpuFusedRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # torch_npu.npu_rms_norm returns (output, rstd); keep only the normalized output
        return torch_npu.npu_rms_norm(x.to(self.weight.dtype), self.weight, epsilon=self.eps)[0]
```

<a id="jump4.2"></a>

### 3. Launching Training

Taking the FLUX model as an example, adjust the [launch script](./posttrain_flux_dancegrpo.sh) to your training setup before starting:

1. Set `NNODES` and `NPUS_PER_NODE` according to your machines; for a single node with 8 NPUs, set `NNODES` to 1 and `NPUS_PER_NODE` to 8;
2. For multi-node training, `MASTER_ADDR` must be identical on every node and must be the IP of one of the nodes; `MASTER_PORT` must be the same port on every node; starting from the node whose IP is `MASTER_ADDR`, set `NODE_RANK` to consecutive integers beginning at 0 (see the sketch after this list);
3. `MM_DATA` is the dataset configuration file, default ./examples/dancegrpo/data_dancegrpo.json;
4. `MM_MODEL` is the model configuration file, default ./examples/dancegrpo/model_dancegrpo.json;
5. `LOAD_PATH` is the load path of the DiT pretrained weights, default ckpt/flux; adjust it to wherever your weights are stored;
6. `SAVE_PATH` is the directory where training checkpoints are saved, default save_dir;
7. `HPS_REWARD_SAVE_PATH` is the file in which reward values are recorded during training, default ./hps_reward.txt.
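For reference, here is a minimal sketch of the per-node settings for a two-node (16-NPU) run. The IP address is a placeholder; substitute the real IP of the node you choose as the master.

```bash
# node 0 (the master; its IP is used as MASTER_ADDR on every node)
NPUS_PER_NODE=8
MASTER_ADDR=192.168.0.1   # placeholder: IP of node 0
MASTER_PORT=6000
NNODES=2
NODE_RANK=0

# node 1: identical settings except NODE_RANK
NPUS_PER_NODE=8
MASTER_ADDR=192.168.0.1
MASTER_PORT=6000
NNODES=2
NODE_RANK=1
```

With these values, `WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))` in the script resolves to 16.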
Before launching training, you can also adjust the dataset configuration [data_dancegrpo.json](./data_dancegrpo.json):
1. `dataset_param.basic_parameters.data_path` is the path to the metadata file videos2caption.json generated by preprocessing.

And the model configuration [model_dancegrpo.json](./model_dancegrpo.json):
1. `reward.ckpt_dir` is the path to the reward model's pretrained weights.

Once the above is configured, launch training:
```bash
bash examples/dancegrpo/posttrain_flux_dancegrpo.sh
```

> *Note: the directory layout of the code, weights, and data must be identical on all nodes, and the training script must be launched from the MindSpeed MM directory on every node.*

After training completes, a run log file is written to the logs directory and a training reward record file is generated.
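To get a quick feel for training progress, you can inspect the reward record file. The snippet below is a minimal sketch; it assumes `HPS_REWARD_SAVE_PATH` keeps its default of ./hps_reward.txt and that each record ends with a numeric reward value (the exact line format written by the trainer may differ).

```bash
# last few recorded reward values
tail -n 5 ./hps_reward.txt
# rough average, assuming the reward is the last whitespace-separated field on each line
awk '{s += $NF; n += 1} END {if (n > 0) printf "mean reward over %d records: %f\n", n, s / n}' ./hps_reward.txt
```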
diff --git a/examples/dancegrpo/data_dancegrpo.json b/examples/dancegrpo/data_dancegrpo.json
new file mode 100644
index 00000000..f7195963
--- /dev/null
+++ b/examples/dancegrpo/data_dancegrpo.json
@@ -0,0 +1,7 @@
{
    "dataset_param": {
        "basic_parameters": {
            "data_path": "./data/rl_embeddings/videos2caption.json"
        }
    }
}
\ No newline at end of file
diff --git a/examples/dancegrpo/model_dancegrpo.json b/examples/dancegrpo/model_dancegrpo.json
new file mode 100644
index 00000000..359897f6
--- /dev/null
+++ b/examples/dancegrpo/model_dancegrpo.json
@@ -0,0 +1,13 @@
{
    "task": "t2i",
    "reward": {
        "model_id": "CLIP-ViT-H-14",
        "model_name": "ViT-H-14",
        "ckpt_dir": "./ckpt/hps_ckpt",
        "pretrained": "open_clip_pytorch_model.bin",
        "load_pt": "HPS_v2.1_compressed.pt"
    },
    "diffusion": {
        "model_id": "FLUX"
    }
}
\ No newline at end of file
diff --git a/examples/dancegrpo/posttrain_flux_dancegrpo.sh b/examples/dancegrpo/posttrain_flux_dancegrpo.sh
new file mode 100644
index 00000000..0782feb8
--- /dev/null
+++ b/examples/dancegrpo/posttrain_flux_dancegrpo.sh
@@ -0,0 +1,109 @@
#!/bin/bash

source /usr/local/Ascend/ascend-toolkit/set_env.sh
# This variable only bypasses Megatron's validation; it has no effect on NPU.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export ASCEND_SLOG_PRINT_TO_STDOUT=0
export ASCEND_GLOBAL_LOG_LEVEL=3
export TASK_QUEUE_ENABLE=2
export COMBINED_ENABLE=1
export CPU_AFFINITY_CONF=2
export HCCL_CONNECT_TIMEOUT=1200
export NPU_ASD_ENABLE=0
export ASCEND_LAUNCH_BLOCKING=0
export ACLNN_CACHE_LIMIT=100000
export MULTI_STREAM_MEMORY_REUSE=2
export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"

export HCCL_BUFFSIZE=800

NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))


MM_DATA="./examples/dancegrpo/data_dancegrpo.json"
MM_MODEL="./examples/dancegrpo/model_dancegrpo.json"
MM_TOOL="./mindspeed_mm/tools/tools.json"
LOAD_PATH="ckpt/flux"
SAVE_PATH="save_dir"
HPS_REWARD_SAVE_PATH="./hps_reward.txt"

DISTRIBUTED_ARGS="
    --nproc_per_node $NPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

GPT_ARGS="
    --seed 42 \
    --load $LOAD_PATH \
    --lr 1.0e-5 \
    --train-iters 300 \
    --weight-decay 0.0001 \
"

MM_ARGS="
    --mm-data $MM_DATA \
    --mm-model $MM_MODEL \
    --mm-tool $MM_TOOL
"

OUTPUT_ARGS="
    --log-interval 1 \
    --save-interval 10000 \
    --eval-interval 10000 \
    --eval-iters 5000 \
    --save $SAVE_PATH \
    --ckpt-format torch \
"

GRPO_ARGS="
    --cache_dir data/.cache \
    --gradient_checkpointing \
    --train_batch_size 1 \
    --num_latent_t 1 \
    --sp_size 1 \
    --train_sp_batch_size 1 \
    --dataloader_num_workers 4 \
    --gradient_accumulation_steps 4 \
    --mixed_precision bf16 \
    --cfg 0.0 \
    --h 720 \
    --w 720 \
    --t 1 \
    --sampling_steps 16 \
    --eta 0.3 \
    --lr_warmup_steps 0 \
    --sampler_seed 1223627 \
    --max_grad_norm 1.0 \
    --use_hpsv2 \
    --num_generations 12 \
    --shift 3 \
    --use_group \
    --ignore_last \
    --timestep_fraction 0.6 \
    --init_same_noise \
    --clip_range 1e-4 \
    --adv_clip_max 5.0 \
    --hps_reward_save $HPS_REWARD_SAVE_PATH \
    --sample_batch_size 12 \
"

logfile=$(date +%Y%m%d)_$(date +%H%M%S)
mkdir -p logs
mkdir -p images
torchrun $DISTRIBUTED_ARGS posttrain_flux_dancegrpo.py \
    $GPT_ARGS \
    $MM_ARGS \
    $OUTPUT_ARGS \
    $GRPO_ARGS \
    --distributed-backend nccl \
    2>&1 | tee logs/train_${logfile}.log
chmod 440 logs/train_${logfile}.log
chmod -R 640 $SAVE_PATH
\ No newline at end of file
diff --git a/examples/dancegrpo/preprocess_flux_rl_embeddings.sh b/examples/dancegrpo/preprocess_flux_rl_embeddings.sh
new file mode 100644
index 00000000..d9c82483
--- /dev/null
+++ b/examples/dancegrpo/preprocess_flux_rl_embeddings.sh
@@ -0,0 +1,35 @@
#!/bin/bash

source /usr/local/Ascend/ascend-toolkit/set_env.sh

NPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=19002
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

LOAD_PATH="ckpt/flux"
OUTPUT_DIR="data/rl_embeddings"
PROMPT_DIR="data/prompts.txt"

DISTRIBUTED_ARGS="
    --nproc_per_node $NPUS_PER_NODE \
    --nnodes $NNODES \
    --node_rank $NODE_RANK \
    --master_addr $MASTER_ADDR \
    --master_port $MASTER_PORT
"

GRPO_ARGS="
    --load $LOAD_PATH \
    --output_dir $OUTPUT_DIR \
    --prompt_dir $PROMPT_DIR \
    --sample_num 50000 \
"

logfile=$(date +%Y%m%d)_$(date +%H%M%S)
mkdir -p logs
torchrun $DISTRIBUTED_ARGS mindspeed_mm/tasks/rl/soragrpo/preprocess/flux_data_preprocess.py \
    $GRPO_ARGS \
    2>&1 | tee logs/preprocess_${logfile}.log
diff --git a/examples/dancegrpo/requirements-lint.txt b/examples/dancegrpo/requirements-lint.txt
new file mode 100644
index 00000000..8b9eefd4
--- /dev/null
+++ b/examples/dancegrpo/requirements-lint.txt
@@ -0,0 +1,76 @@
# formatting
yapf==0.32.0
toml==0.10.2
tomli==2.0.2
ruff==0.6.5
codespell==2.3.0
isort==5.13.2
sphinx-lint==1.0.0
torch==2.7.1

ml-collections==1.1.0
absl-py==2.3.1
inflect==6.0.4
pydantic==2.9.2
protobuf==5.29.5
packaging==25.0
ninja==1.11.1.4
trl==0.19.1

accelerate==1.9.0
bitsandbytes==0.46.1
transformers==4.54.0
tokenizers==0.21.2
albumentations==1.4.20
av==13.1.0
einops==0.8.0
fastapi==0.115.3
gdown==5.2.0
h5py==3.12.1
idna==3.7
imageio==2.36.0
matplotlib==3.9.2
numpy==1.26.3
omegaconf==2.3.0
opencv-python==4.10.0.84
opencv-python-headless==4.10.0.84
pandas==2.2.3
pillow==10.3.0
pydub==0.25.1
pytorch-lightning==2.4.0
pytorchvideo==0.1.5
PyYAML==6.0.1
regex==2024.9.11
requests==2.32.4
scikit-learn==1.5.2
scipy==1.13.0
six==1.16.0
test-tube==0.7.5
timm==1.0.11
torchdiffeq==0.2.4
torchmetrics==1.5.1
tqdm==4.66.5
urllib3==2.5.0
uvicorn==0.32.0
scikit-video==1.1.11
imageio-ffmpeg==0.5.1
sentencepiece==0.2.0
beautifulsoup4==4.12.3
ftfy==6.3.0
moviepy==1.0.3
wandb==0.18.5
watch==0.2.7
gpustat==1.1.1
peft==0.13.2
wheel==0.44.0
loguru==0.7.3
diffusers==0.32.0


# type checking
mypy==1.11.1
types-PyYAML
types-requests
types-setuptools
diff --git a/posttrain_flux_dancegrpo.py b/posttrain_flux_dancegrpo.py
new file mode 100644
index 00000000..523be4dd
--- /dev/null
+++ b/posttrain_flux_dancegrpo.py
@@ -0,0 +1,20 @@
# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
# Apply MindSpeed's NPU patches to Megatron before importing Megatron-dependent modules.
import mindspeed.megatron_adaptor
from mindspeed_mm.tasks.rl.soragrpo.dataset.latent_flux_rl_datasets import LatentDataset
from mindspeed_mm.tasks.rl.soragrpo.flux_grpo_trainer import FluxGRPOTrainer


def train_valid_test_datasets_provider(args):
    """Build train, valid, and test datasets."""
    # Only a training split is needed for GRPO post-training; it reads the metadata
    # produced by preprocess_flux_rl_embeddings.sh (videos2caption.json).
    train_dataset = LatentDataset(args.mm.data.dataset_param.basic_parameters.data_path, args.num_latent_t, args.cfg)
    return train_dataset


if __name__ == "__main__":
    train_valid_test_datasets_provider.is_distributed = True

    trainer = FluxGRPOTrainer(
        train_valid_test_dataset_provider=train_valid_test_datasets_provider,
    )
    trainer.train()
--
Gitee