简体中文 |
扩散模型(Diffusion Models)是一种生成模型,可生成各种各样的高分辨率图像。Diffusers 是 HuggingFace 发布的模型套件,是最先进的预训练扩散模型的首选库,用于生成图像,音频,甚至分子的3D结构。套件包含基于扩散模型的多种模型,提供了各种下游任务的训练与推理的实现。
参考实现:
url=https://github.com/huggingface/diffusers
commit_id=5956b68a6927126daffc2c5a6d1a9a189defe288
【模型开发时推荐使用配套的环境版本】
软件 | 版本 | 安装指南 |
---|---|---|
Python | 3.8 | |
Driver | AscendHDK 24.1.RC3 | 《驱动固件安装指南 》 |
Firmware | AscendHDK 24.1.RC3 | |
CANN | CANN 8.0.RC3 | 《CANN 软件安装指南 》 |
Torch | 2.1.0 | 《Ascend Extension for PyTorch 配置与安装 》 |
Torch_npu | release v6.0.RC3 |
torch npu 与 CANN包参考链接:安装包参考链接
```bash
# python3.8
conda create -n test python=3.8
conda activate test
# 安装 torch 和 torch_npu,注意要选择对应python版本、x86或arm的torch、torch_npu及apex包
pip install torch-2.1.0-cp38-cp38m-manylinux2014_aarch64.whl
pip install torch_npu-2.1.0*-cp38-cp38m-linux_aarch64.whl
# apex for Ascend 参考 https://gitee.com/ascend/apex
pip install apex-0.1_ascend*-cp38-cp38m-linux_aarch64.whl
# 将shell脚本中的环境变量路径修改为真实路径,下面为参考路径
source /usr/local/Ascend/ascend-toolkit/set_env.sh
```
克隆仓库到本地服务器
git clone --branch 1.0.RC3 https://gitee.com/ascend/MindSpeed-MM.git
模型搭建
3.1 【下载 SDXL GitHub参考实现或 在模型根目录下执行以下命令,安装模型对应PyTorch版本需要的依赖】
git clone https://github.com/huggingface/diffusers.git -b v0.30.0
cd diffusers
git checkout 5956b68a6927126daffc2c5a6d1a9a189defe288
cp -r ../MindSpeed-MM/examples/diffusers/sdxl ./sdxl
【主要代码路径】
code_path=examples/text_to_image/
3.2【安装 {任务pretrain/train}_sdxl_deepspeed_{混精fp16/bf16}.sh
】
转移 collect_dataset.py
与 pretrain_model.py
与 train_text_to_image_sdxl_pretrain.py
与 patch_sdxl.py
到 examples/text_to_image/
路径
# Example: 需要修改.py名字进行四次任务
cp ./sdxl/train_text_to_image_sdxl_pretrain.py ./examples/text_to_image/
3.3【安装其余依赖库】
pip install -e .
vim examples/text_to_image/requirements_sdxl.txt #修改torchvision版本:torchvision==0.16.0, torch==2.1.0
pip install -r examples/text_to_image/requirements_sdxl.txt # 安装diffusers原仓对应依赖
pip install -r sdxl/requirements_sdxl_extra.txt #安装sdxl对应依赖
【准备预训练数据集】
用户需自行获取并解压laion_sx数据集(目前数据集暂已下架,可选其他数据集)与pokemon-blip-captions数据集,并在以下启动shell脚本中将dataset_name
参数设置为本地数据集的绝对路径
修改pretrain_sdxl_deepspeed_**16.sh
的dataset_name为laion_sx
的绝对路径
vim sdxl/pretrain_sdxl_deepspeed_**16.sh
修改train_sdxl_deepspeed_**16.sh
的dataset_name为pokemon-blip-captions
的绝对路径
vim sdxl/train_sdxl_deepspeed_**16.sh
laion_sx数据集格式如下:
laion_sx数据集格式如下
├── 000000000.jpg
├── 000000000.json
├── 000000000.txt
pokemon-blip-captions数据集格式如下:
pokemon-blip-captions
├── dataset_infos.json
├── README.MD
└── data
└── train-001.parquet
说明: 该数据集的训练过程脚本只作为一种参考示例。
【配置 SDXL 预训练脚本与预训练模型】
联网情况下,预训练模型可通过以下步骤下载。无网络时,用户可访问huggingface官网自行下载sdxl-base模型 model_name
模型与sdxl-vae模型 vae_name
export model_name="stabilityai/stable-diffusion-xl-base-1.0" # 预训练模型路径
export vae_name="madebyollin/sdxl-vae-fp16-fix" # vae模型路径
获取对应的预训练模型后,在以下shell启动脚本中将model_name
参数设置为本地预训练模型绝对路径,将vae_name
参数设置为本地vae
模型绝对路径
scripts_path="./sdxl" # 模型根目录(模型文件夹名称)
model_name="stabilityai/stable-diffusion-xl-base-1.0" # 预训练模型路径
vae_name="madebyollin/sdxl-vae-fp16-fix" # vae模型路径
dataset_name="laion_sx" # 数据集路径
batch_size=4
max_train_steps=2000
mixed_precision="bf16" # 混精
resolution=1024
config_file="${scripts_path}/pretrain_${mixed_precision}_accelerate_config.yaml"
修改bash文件中accelerate
配置下train_text_to_image_sdxl_pretrain.py
的路径(默认路径在diffusers/sdxl/)
accelerate launch --config_file ${config_file} \
${scripts_path}/train_text_to_image_sdxl_pretrain.py \ #如模型根目录为sdxl则无需修改
修改pretrain_fp16_accelerate_config.yaml
的deepspeed_config_file
的路径:
deepspeed_config_file: ./sdxl/deepspeed_fp16.json # deepspeed JSON文件路径
修改examples/text_to_image/train_text_to_image_sdxl.py
文件
vim examples/text_to_image/train_text_to_image_sdxl.py
在文件58行修改修改version
# 讲minimum version从0.31.0修改为0.30.0
check_min_version("0.30.0")
在文件59行添加代码
from patch_sdxl import TorchPatcher, compute_vae_encode, config_gc
TorchPatcher.apply_patch()
config_gc()
【Optional】在文件918行左右将compute_vae_encodings_fn
进行修改
compute_vae_encodings_fn = functools.partial(compute_vae_encode, accelerator=accelerator, vae=vae)
【Optional】Ubuntu系统需在文件1216行附近添加 accelerator.print("")
if global_step >= args.max_train_steps:
break
accelerator.print("")
【FPS打印方式请参考train_text_to_image_sdxl_pretrain.py】
【启动 SDXL 预训练脚本】
本任务主要提供混精fp16和混精bf16两种8卡训练脚本,默认使用deepspeed分布式训练。
pretrain模型主要来承担第二阶段的文生图的训练 train模型主要来承担第一阶段的文生图的训练功能
bash sdxl/pretrain_sdxl_deepspeed_**16.sh
bash sdxl/train_sdxl_deepspeed_**16.sh
SDXL 在 昇腾芯片 和 参考芯片 上的性能对比:
芯片 | 卡数 | 任务 | FPS | batch_size | AMP_Type | Torch_Version | deepspeed |
---|---|---|---|---|---|---|---|
竞品A | 8p | SDXL_train_bf16 | 30.65 | 4 | bf16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | SDXL_train_bf16 | 29.92 | 4 | bf16 | 2.1 | ✔ |
竞品A | 8p | SDXL_train_fp16 | 30.23 | 4 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | SDXL_train_fp16 | 28.51 | 4 | fp16 | 2.1 | ✔ |
竞品A | 8p | SDXL_pretrain_bf16 | 21.14 | 4 | bf16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | SDXL_pretrain_bf16 | 19.79 | 4 | bf16 | 2.1 | ✔ |
竞品A | 8p | SDXL_pretrain_fp16 | 20.77 | 4 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | SDXL_pretrain_fp16 | 19.67 | 4 | fp16 | 2.1 | ✔ |
说明: 环境搭建同预训练。数据集同预训练的
pokemon-blip-captions
,请参考预训练章节。
sdxl/finetune_sdxl_lora_deepspeed_fp16.sh
用户需自行获取fill50k数据集,并在以下启动shell脚本中将dataset_name
参数设置为本地数据集的绝对路径,以及需要修改里面fill50k.py文件
sdxl/finetune_sdxl_controlnet_deepspeed_fp16.sh
参考如下修改controlnet/train_controlnet_sdxl.py, 追加trust_remote_code=True
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
trust_remote_code=True
)
注意: 需要修改数据集下面的fill50k.py文件中的57到59行,修改示例如下:
metadata_path = "数据集路径/fill50k/train.jsonl" images_dir = "数据集路径/fill50k" conditioning_images_dir = "数据集路径/fill50k"
fill50k数据集格式如下:
fill50k
├── images
├── conditioning_images
├── train.jsonl
└── fill50k.py
说明: 该数据集的训练过程脚本只作为一种参考示例。
说明: 数据集同Lora微调,请参考Lora章节。
获取sdxl-base模型 model_name
模型与sdxl-vae模型 vae_name
。
获取对应的预训练模型后,在Controlnet微调
shell启动脚本中将model_name
参数设置为本地预训练模型绝对路径,将vae_name
参数设置为本地vae
模型绝对路径。
sdxl/finetune_sdxl_controlnet_deepspeed_fp16.sh
Lora微调
与全参微调
shell启动脚本中将model_name
参数设置为本地预训练模型绝对路径
sdxl/finetune_sdxl_deepspeed_fp16.sh
sdxl/finetune_sdxl_lora_deepspeed_fp16.sh
说明: 预训练模型同预训练,请参考预训练章节。
【Optional】如是Ubuntu系统需在 examples/text_to_image/train_text_to_image_lora_sdxl.py
与 examples/controlnet/train_controlnet_sdxl.py
添加 accelerator.print("")
:参考
注意 train_text_to_image_lora_sdxl 在1235行附近添加; train_controlnet_sdxl 在1307行附近添加
【运行微调的脚本】
```shell
# 单机八卡微调
# finetune_sdxl_controlnet_deepspeed_fp16.sh 中依赖的图片,可以通过下面命令下载
# wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
# wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
bash sdxl/finetune_sdxl_controlnet_deepspeed_fp16.sh #8卡deepspeed训练 sdxl_controlnet fp16
bash sdxl/finetune_sdxl_lora_deepspeed_fp16.sh #8卡deepspeed训练 sdxl_lora fp16
bash sdxl/finetune_sdxl_deepspeed_fp16.sh #8卡deepspeed训练 sdxl_finetune fp16
```
芯片 | 卡数 | 任务 | FPS | batch_size | AMP_Type | Torch_Version | deepspeed |
---|---|---|---|---|---|---|---|
竞品A | 8p | LoRA | 31.74 | 7 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | LoRA | 26.40 | 7 | fp16 | 2.1 | ✔ |
竞品A | 8p | Controlnet | 32.44 | 5 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | Controlnet | 29.98 | 5 | fp16 | 2.1 | ✔ |
竞品A | 8p | Finetune | 164.66 | 24 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | Finetune | 166.71 | 24 | fp16 | 2.1 | ✔ |
同微调对应章节
【运行推理的脚本】
单机单卡推理,脚本配置
调用推理脚本
python sdxl/sdxl_text2img_lora_infer.py # 混精fp16 文生图lora微调任务推理
python sdxl/sdxl_text2img_controlnet_infer.py # 混精fp16 文生图controlnet微调任务推理
python sdxl/sdxl_text2img_infer.py # 混精fp16 文生图全参微调任务推理
python sdxl/sdxl_img2img_infer.py # 混精fp16 图生图微调任务推理
芯片 | 卡数 | 任务 | E2E(it/s) | AMP_Type | Torch_Version | deepspeed |
---|---|---|---|---|---|---|
竞品A | 8p | 文生图lora | 1.45 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | 文生图lora | 2.61 | fp16 | 2.1 | ✔ |
竞品A | 8p | 文生图controlnet | 1.41 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | 文生图controlnet | 2.97 | fp16 | 2.1 | ✔ |
竞品A | 8p | 文生图全参 | 1.55 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | 文生图全参 | 3.02 | fp16 | 2.1 | ✔ |
竞品A | 8p | 图生图 | 3.56 | fp16 | 2.1 | ✔ |
Atlas 900 A2 PODc | 8p | 图生图 | 3.94 | fp16 | 2.1 | ✔ |
代码涉及公网地址参考 docs/public_address_statement.md
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。