import argparse
import ast
import asyncio
import functools
import gc
import glob
import hashlib
import importlib
import json
import math
import os
import pathlib
import random
import re
import shutil
import time
from dataclasses import asdict, dataclass
from io import BytesIO
from pathlib import Path
from textwrap import dedent, indent
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import safetensors.torch
import toml
import torch
import torch_npu
import transformers
import voluptuous
from accelerate import Accelerator, InitProcessGroupKwargs
from huggingface_hub import hf_hub_download
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Required, Schema
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
IMAGE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
".bmp",
".PNG",
".JPG",
".JPEG",
".WEBP",
".BMP",
]
try:
import pillow_avif
IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
pass
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
pass
# JPEG-XL on Windows
try:
import pillow_jxl
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
pass
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
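# Note: ToTensor converts a uint8 HWC image in [0, 255] to a float CHW tensor in
# [0, 1], and Normalize([0.5], [0.5]) then maps it to [-1, 1] via (x - 0.5) / 0.5.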
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
def load_image(image_path):
with Image.open(image_path) as image:
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
# Trims and resizes an image. Returns numpy.ndarray, (original width, original height), (crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# Resize
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
image_height, image_width = image.shape[0:2]
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
image = image[:, p : p + reso[0]]
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
image = image[p : p + reso[1]]
crop_ltrb = BucketManager.get_crop_ltrb(reso, original_size)
    if image.shape[0] != reso[1] or image.shape[1] != reso[0]:
raise ValueError(f"internal error, illegal trimmed size: {image.shape}, {reso}")
return image, original_size, crop_ltrb
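# Illustrative example for trim_and_resize_if_required (values assumed): with
# reso=(1024, 576) and resized_size=(1088, 612), a 1280x720 input is first resized
# to 1088x612, then center-cropped (or randomly cropped when random_crop=True)
# by 64px horizontally and 36px vertically down to the 1024x576 bucket.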
class ImageInfo:
def __init__(
self,
image_key: str,
num_repeats: int,
caption: str,
is_reg: bool,
absolute_path: str,
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: torch.Tensor = None
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_original_size: Tuple[int, int] = (
None # original image size, not latents size
)
self.latents_crop_ltrb: Tuple[int, int] = (
None # crop left top right bottom in original pixel size, not latents size
)
self.cond_img_path: str = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
# SDXL, optional
self.text_encoder_outputs_npz: Optional[str] = None
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
if max_size is not None:
if max_reso is not None:
if max_size < max_reso[0]:
raise ValueError(
"the max_size should be larger than the width of max_reso"
)
if max_size < max_reso[1]:
raise ValueError(
"the max_size should be larger than the height of max_reso"
)
if min_size is not None:
if max_size < min_size:
raise ValueError("the max_size should be larger than the min_size")
self.no_upscale = no_upscale
if max_reso is None:
self.max_reso = None
self.max_area = None
else:
self.max_reso = max_reso
self.max_area = max_reso[0] * max_reso[1]
self.min_size = min_size
self.max_size = max_size
self.reso_steps = reso_steps
self.resos = []
self.reso_to_id = {}
        self.buckets = (
            []
        )  # during preprocessing: (image key, image, original size, crop left/top); during training: image key
    def add_image(self, reso, image_or_info):
        bucket_id = self.reso_to_id.get(reso, None)
        if bucket_id is not None:
            self.buckets[bucket_id].append(image_or_info)
def shuffle(self):
for bucket in self.buckets:
random.shuffle(bucket)
def sort(self):
        # Sort by resolution (purely cosmetic, for display and metadata storage). Reorder buckets and rebuild reso_to_id to match.
sorted_resos = self.resos.copy()
sorted_resos.sort()
sorted_buckets = []
sorted_reso_to_id = {}
        for i, reso in enumerate(sorted_resos):
            bucket_id = self.reso_to_id.get(reso, None)
            if bucket_id is not None:
                sorted_buckets.append(self.buckets[bucket_id])
                sorted_reso_to_id[reso] = i
self.resos = sorted_resos
self.buckets = sorted_buckets
self.reso_to_id = sorted_reso_to_id
def make_buckets(self):
resos = make_bucket_resolutions(
self.max_reso, self.min_size, self.max_size, self.reso_steps
)
self.set_predefined_resos(resos)
def set_predefined_resos(self, resos):
        # store the resolutions and aspect ratios used when selecting from predefined sizes
self.predefined_resos = resos.copy()
self.predefined_resos_set = set(resos)
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
def add_if_new_reso(self, reso):
if reso not in self.reso_to_id:
bucket_id = len(self.resos)
self.reso_to_id[reso] = bucket_id
self.resos.append(reso)
self.buckets.append([])
# print(reso, bucket_id, len(self.buckets))
def round_to_steps(self, x):
x = int(x + 0.5)
return x - x % self.reso_steps
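    # round_to_steps example (illustrative): with reso_steps=64, round_to_steps(300)
    # -> 256 and round_to_steps(319.7) -> 320, i.e. round to the nearest integer,
    # then floor to a multiple of reso_steps.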
def select_bucket(self, image_width, image_height):
aspect_ratio = image_width / image_height
if not self.no_upscale:
            # both upscaling and downscaling are allowed
            # the same aspect ratio may already exist (e.g. when preprocessing with no_upscale=True for fine tuning), so prefer an exact resolution match
reso = (image_width, image_height)
if reso in self.predefined_resos_set:
pass
else:
                ar_errors = self.predefined_aspect_ratios - aspect_ratio
                predefined_bucket_id = np.abs(
                    ar_errors
                ).argmin()  # the predefined reso with the smallest aspect ratio error
                reso = self.predefined_resos[predefined_bucket_id]
ar_reso = reso[0] / reso[1]
            if aspect_ratio > ar_reso:  # the image is wider -> match the height
scale = reso[1] / image_height
else:
scale = reso[0] / image_width
resized_size = (
int(image_width * scale + 0.5),
int(image_height * scale + 0.5),
)
# print("use predef", image_width, image_height, reso, resized_size)
else:
            # downscale only
if image_width * image_height > self.max_area:
                # the image is too large, so pick the bucket assuming it is shrunk while keeping the aspect ratio
resized_width = math.sqrt(self.max_area * aspect_ratio)
resized_height = self.max_area / resized_width
if abs(resized_width / resized_height - aspect_ratio) >= 1e-2:
raise ValueError("aspect is illegal")
                # round the resized short or long side to a multiple of reso_steps, choosing whichever gives the smaller aspect ratio error
                # same logic as the original bucketing
b_width_rounded = self.round_to_steps(resized_width)
b_height_in_wr = self.round_to_steps(b_width_rounded / aspect_ratio)
ar_width_rounded = b_width_rounded / b_height_in_wr
b_height_rounded = self.round_to_steps(resized_height)
b_width_in_hr = self.round_to_steps(b_height_rounded * aspect_ratio)
ar_height_rounded = b_width_in_hr / b_height_rounded
# print(b_width_rounded, b_height_in_wr, ar_width_rounded)
# print(b_width_in_hr, b_height_rounded, ar_height_rounded)
if abs(ar_width_rounded - aspect_ratio) < abs(
ar_height_rounded - aspect_ratio
):
resized_size = (
b_width_rounded,
int(b_width_rounded / aspect_ratio + 0.5),
)
else:
resized_size = (
int(b_height_rounded * aspect_ratio + 0.5),
b_height_rounded,
)
# print(resized_size)
else:
                resized_size = (image_width, image_height)  # no resize needed
            # make the bucket no larger than the image (crop instead of padding)
bucket_width = resized_size[0] - resized_size[0] % self.reso_steps
bucket_height = resized_size[1] - resized_size[1] % self.reso_steps
# print("use arbitrary", image_width, image_height, resized_size, bucket_width, bucket_height)
reso = (bucket_width, bucket_height)
self.add_if_new_reso(reso)
ar_error = (reso[0] / reso[1]) - aspect_ratio
return reso, resized_size, ar_error
@staticmethod
def get_crop_ltrb(bucket_reso: Tuple[int, int], image_size: Tuple[int, int]):
        # Calculate crop left/top according to the preprocessing of Stability AI. Crop right/bottom are also returned for flip augmentation.
bucket_ar = bucket_reso[0] / bucket_reso[1]
image_ar = image_size[0] / image_size[1]
if bucket_ar > image_ar:
            # the bucket is wider than the image -> match the height
resized_width = bucket_reso[1] * image_ar
resized_height = bucket_reso[1]
else:
resized_width = bucket_reso[0]
resized_height = bucket_reso[0] / image_ar
crop_left = (bucket_reso[0] - resized_width) // 2
crop_top = (bucket_reso[1] - resized_height) // 2
crop_right = crop_left + resized_width
crop_bottom = crop_top + resized_height
return crop_left, crop_top, crop_right, crop_bottom
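    # Worked example for get_crop_ltrb (illustrative values): bucket_reso=(1024, 1024),
    # image_size=(1280, 720) -> image_ar ≈ 1.78 > bucket_ar = 1.0, so the resized size
    # is (1024, 576) and the result is (0, 224, 1024, 800).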
class BucketBatchIndex(NamedTuple):
bucket_index: int
bucket_batch_size: int
batch_index: int
class AugHelper:
    # the dependency on albumentations has been removed, but the same interface is kept for now
def __init__(self):
pass
def color_aug(self, image: np.ndarray):
hue_shift_limit = 8
# remove dependency to albumentations
if random.random() <= 0.33:
if random.random() > 0.5:
# hue shift
hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
if hue_shift < 0:
hue_shift = 180 + hue_shift
hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
else:
# random gamma
gamma = random.uniform(0.95, 1.05)
image = np.clip(image**gamma, 0, 255).astype(np.uint8)
return {"image": image}
def get_augmentor(
self, use_color_aug: bool
): # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
return self.color_aug if use_color_aug else None
class BaseSubset:
def __init__(
self,
image_dir: Optional[str],
num_repeats: int,
shuffle_caption: bool,
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
random_crop: bool,
caption_dropout_rate: float,
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
caption_prefix: Optional[str],
caption_suffix: Optional[str],
token_warmup_min: int,
token_warmup_step: Union[float, int],
) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
self.random_crop = random_crop
self.caption_dropout_rate = caption_dropout_rate
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.caption_prefix = caption_prefix
self.caption_suffix = caption_suffix
        self.token_warmup_min = token_warmup_min  # number of tags at step 0
        self.token_warmup_step = token_warmup_step  # tag count reaches its maximum at step N (N * max_train_steps if N < 1)
self.img_count = 0
class DreamBoothSubset(BaseSubset):
def __init__(
self,
image_dir: str,
is_reg: bool,
class_tokens: Optional[str],
caption_extension: str,
num_repeats,
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
) -> None:
if image_dir is None:
raise ValueError("image_dir must be specified")
super().__init__(
image_dir,
num_repeats,
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
random_crop,
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
caption_prefix,
caption_suffix,
token_warmup_min,
token_warmup_step,
)
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
return NotImplemented
return self.image_dir == other.image_dir
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self,
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
) -> None:
super().__init__()
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.network_multiplier = network_multiplier
self.debug_dataset = debug_dataset
self.subsets: List[Union[DreamBoothSubset, FineTuningSubset]] = []
self.token_padding_disabled = False
self.tag_frequency = {}
self.XTI_layers = None
self.token_strings = None
self.enable_bucket = False
self.bucket_manager: BucketManager = None # not initialized
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None
self.bucket_no_upscale = None
self.bucket_info = None # for metadata
self.tokenizer_max_length = (
self.tokenizers[0].model_max_length
if max_token_length is None
else max_token_length + 2
)
        self.current_epoch: int = (
            0  # must be set from outside; the dataset instance seems to be recreated for each epoch
        )
self.current_step: int = 0
self.max_train_steps: int = 0
self.seed: int = 0
# augmentation
self.aug_helper = AugHelper()
self.image_transforms = IMAGE_TRANSFORMS
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
self.replacements = {}
# caching
self.caching_mode = None # None, 'latents', 'text'
def set_seed(self, seed):
self.seed = seed
def set_caching_mode(self, mode):
self.caching_mode = mode
def set_current_epoch(self, epoch):
        if self.current_epoch != epoch:  # shuffle the buckets when the epoch changes
self.shuffle_buckets()
self.current_epoch = epoch
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
for caption in captions:
for tag in caption.split(","):
tag = tag.strip()
if tag:
tag = tag.lower()
frequency = frequency_for_dir.get(tag, 0)
frequency_for_dir[tag] = frequency + 1
def disable_token_padding(self):
self.token_padding_disabled = True
def enable_XTI(self, layers=None, token_strings=None):
self.XTI_layers = layers
self.token_strings = token_strings
def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to
def process_caption(self, subset: BaseSubset, caption):
        # add prefix/suffix to the caption
if subset.caption_prefix:
caption = subset.caption_prefix + " " + caption
if subset.caption_suffix:
caption = caption + " " + subset.caption_suffix
        # decide dropout: tag dropout also happens inside this method, so deciding here is convenient
is_drop_out = (
subset.caption_dropout_rate > 0
and random.random() < subset.caption_dropout_rate
)
is_drop_out = (
is_drop_out
or subset.caption_dropout_every_n_epochs > 0
and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
)
if is_drop_out:
caption = ""
else:
if (
subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
):
fixed_tokens = []
flex_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(
subset.keep_tokens_separator, 1
)
fixed_tokens = [
t.strip()
for t in fixed_part.split(subset.caption_separator)
if t.strip()
]
flex_tokens = [
t.strip()
for t in flex_part.split(subset.caption_separator)
if t.strip()
]
else:
tokens = [
t.strip()
for t in caption.strip().split(subset.caption_separator)
]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
                if subset.token_warmup_step < 1:  # overwritten on the first call (fraction of max_train_steps -> steps)
subset.token_warmup_step = math.floor(
subset.token_warmup_step * self.max_train_steps
)
if (
subset.token_warmup_step
and self.current_step < subset.token_warmup_step
):
tokens_len = (
math.floor(
(self.current_step)
* (
(len(flex_tokens) - subset.token_warmup_min)
/ (subset.token_warmup_step)
)
)
+ subset.token_warmup_min
)
flex_tokens = flex_tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
return tokens
tags = []
for token in tokens:
if random.random() >= subset.caption_tag_dropout_rate:
tags.append(token)
return tags
if subset.shuffle_caption:
random.shuffle(flex_tokens)
flex_tokens = dropout_tags(flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens)
        # textual inversion support
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if isinstance(str_to, list):
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption
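    # process_caption example (illustrative): with caption_separator="," and
    # keep_tokens_separator="|||", the caption "1girl, solo ||| smile, outdoors"
    # keeps "1girl, solo" fixed while "smile, outdoors" may be shuffled or dropped;
    # with keep_tokens=2 and no separator, the first two tags are kept fixed instead.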
def get_input_ids(self, caption, tokenizer=None):
if tokenizer is None:
tokenizer = self.tokenizers[0]
input_ids = tokenizer(
caption,
padding="max_length",
truncation=True,
max_length=self.tokenizer_max_length,
return_tensors="pt",
).input_ids
if self.tokenizer_max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
                # when longer than 77 tokens the ids look like "<BOS> .... <EOS> <EOS> <EOS>" with a total of e.g. 227, so convert them into repeated "<BOS>...<EOS>" chunks
                # AUTOMATIC1111's implementation splits on "," etc., but keep it simple here
for i in range(
1,
self.tokenizer_max_length - tokenizer.model_max_length + 2,
tokenizer.model_max_length - 2,
): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
                # when longer than 77 tokens the ids look like "<BOS> .... <EOS> <PAD> <PAD>..." with a total of e.g. 227, so convert them into repeated "<BOS>...<EOS> <PAD> <PAD> ..." chunks
for i in range(
1,
self.tokenizer_max_length - tokenizer.model_max_length + 2,
tokenizer.model_max_length - 2,
):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
                    # if the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done
                    # if it ends with x <PAD/EOS>, change the last token to <EOS> (no change if it was already x <EOS>)
if (
ids_chunk[-2] != tokenizer.eos_token_id
and ids_chunk[-2] != tokenizer.pad_token_id
):
ids_chunk[-1] = tokenizer.eos_token_id
                    # if the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
return input_ids
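    # Chunking example for get_input_ids (illustrative): with model_max_length=77 and
    # max_token_length=225, tokenizer_max_length is 227 and range(1, 152, 75) yields
    # three 75-token chunks, each re-wrapped with BOS/EOS into a (3, 77) tensor.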
def register_image(self, info: ImageInfo, subset: BaseSubset):
self.image_data[info.image_key] = info
self.image_to_subset[info.image_key] = subset
def make_buckets(self):
"""
        must be called even when bucketing is disabled (a single bucket is created)
min_size and max_size are ignored when enable_bucket is False
"""
print("loading image sizes.")
for info in tqdm(self.image_data.values()):
if info.image_size is None:
info.image_size = self.get_image_size(info.absolute_path)
if self.enable_bucket:
print("make buckets")
else:
print("prepare dataset")
        # create buckets and assign each image to one
if self.enable_bucket:
            if (
                self.bucket_manager is None
            ):  # already initialized when fine tuning and bucket info is defined in the metadata
self.bucket_manager = BucketManager(
self.bucket_no_upscale,
(self.width, self.height),
self.min_bucket_reso,
self.max_bucket_reso,
self.bucket_reso_steps,
)
if not self.bucket_no_upscale:
self.bucket_manager.make_buckets()
else:
print(
"min_bucket_reso and max_bucket_reso are ignored if bucket_no_upscale is set, because bucket reso is defined by image size automatically / bucket_no_upscaleが指定された場合は、bucketの解像度は画像サイズから自動計算されるため、min_bucket_resoとmax_bucket_resoは無視されます"
)
img_ar_errors = []
for image_info in self.image_data.values():
image_width, image_height = image_info.image_size
image_info.bucket_reso, image_info.resized_size, ar_error = (
self.bucket_manager.select_bucket(image_width, image_height)
)
# print(image_info.image_key, image_info.bucket_reso)
img_ar_errors.append(abs(ar_error))
self.bucket_manager.sort()
else:
self.bucket_manager = BucketManager(
False, (self.width, self.height), None, None, None
)
            self.bucket_manager.set_predefined_resos(
                [(self.width, self.height)]
            )  # only a single fixed-size bucket
for image_info in self.image_data.values():
image_width, image_height = image_info.image_size
image_info.bucket_reso, image_info.resized_size, _ = (
self.bucket_manager.select_bucket(image_width, image_height)
)
for image_info in self.image_data.values():
for _ in range(image_info.num_repeats):
self.bucket_manager.add_image(
image_info.bucket_reso, image_info.image_key
)
        # display and store bucket info
if self.enable_bucket:
self.bucket_info = {"buckets": {}}
print(
"number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)"
)
for i, (reso, bucket) in enumerate(
zip(self.bucket_manager.resos, self.bucket_manager.buckets)
):
count = len(bucket)
if count > 0:
self.bucket_info["buckets"][i] = {
"resolution": reso,
"count": len(bucket),
}
print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")
        # build the data-reference index; it is used when shuffling the dataset
        self.buckets_indices: List[BucketBatchIndex] = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
batch_count = int(math.ceil(len(bucket) / self.batch_size))
for batch_index in range(batch_count):
self.buckets_indices.append(
BucketBatchIndex(bucket_index, self.batch_size, batch_index)
)
self.shuffle_buckets()
self._length = len(self.buckets_indices)
def shuffle_buckets(self):
# set random seed for this epoch
random.seed(self.seed + self.current_epoch)
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()
def verify_bucket_reso_steps(self, min_steps: int):
if (
self.bucket_reso_steps is not None
and self.bucket_reso_steps % min_steps != 0
):
raise ValueError(
f"bucket_reso_steps is {self.bucket_reso_steps}. it must be divisible by {min_steps}.\n"
+ f"bucket_reso_stepsが{self.bucket_reso_steps}です。{min_steps}で割り切れる必要があります"
)
def is_latent_cacheable(self):
return all(
[not subset.color_aug and not subset.random_crop for subset in self.subsets]
)
def is_text_encoder_output_cacheable(self):
return all(
[
not (
subset.caption_dropout_rate > 0
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
)
for subset in self.subsets
]
)
def cache_latents(
self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True
):
        # multi-GPU is not supported here; use tools/cache_latents.py for that
print("caching latents.")
image_infos = list(self.image_data.values())
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
print("checking cache validity...")
for info in tqdm(image_infos):
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: # fine tuning dataset
continue
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
if not is_main_process: # store to info only
continue
cache_available = is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, subset.flip_aug
)
if cache_available: # do not add to batch
continue
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
batch = []
batch.append(info)
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
        if (
            cache_to_disk and not is_main_process
        ):  # when caching to disk, non-main processes only store the npz path to info
            return
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
print("caching latents...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(
vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop
)
    # when weight_dtype is specified, the text encoders (and hence their outputs) are cast to weight_dtype
def cache_text_encoder_outputs(
self,
tokenizers,
text_encoders,
device,
weight_dtype,
cache_to_disk=False,
is_main_process=True,
):
if len(tokenizers) != 2:
raise ValueError("only support SDXL")
print("caching text encoder outputs.")
image_infos = list(self.image_data.values())
print("checking cache existence...")
image_infos_to_cache = []
for info in tqdm(image_infos):
# subset = self.image_to_subset[info.image_key]
if cache_to_disk:
te_out_npz = (
os.path.splitext(info.absolute_path)[0]
+ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX
)
info.text_encoder_outputs_npz = te_out_npz
if not is_main_process: # store to info only
continue
if os.path.exists(te_out_npz):
continue
image_infos_to_cache.append(info)
        if (
            cache_to_disk and not is_main_process
        ):  # when caching to disk, don't cache text encoder outputs in non-main processes; only the path is stored to info
            return
# prepare tokenizers and text encoders
for text_encoder in text_encoders:
text_encoder.to(device)
if weight_dtype is not None:
text_encoder.to(dtype=weight_dtype)
# create batch
batch = []
batches = []
for info in image_infos_to_cache:
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
batch.append((info, input_ids1, input_ids2))
if len(batch) >= self.batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
# iterate batches: call text encoder and cache outputs for memory or disk
print("caching text encoder outputs...")
for batch in tqdm(batches):
infos, input_ids1, input_ids2 = zip(*batch)
input_ids1 = torch.stack(input_ids1, dim=0)
input_ids2 = torch.stack(input_ids2, dim=0)
cache_batch_text_encoder_outputs(
infos,
tokenizers,
text_encoders,
self.max_token_length,
cache_to_disk,
input_ids1,
input_ids2,
weight_dtype,
)
def get_image_size(self, image_path):
with Image.open(image_path) as image:
return image.size
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = load_image(image_path)
face_cx = face_cy = face_w = face_h = 0
if subset.face_crop_aug_range is not None:
tokens = os.path.splitext(os.path.basename(image_path))[0].split("_")
if len(tokens) >= 5:
face_cx = int(tokens[-4])
face_cy = int(tokens[-3])
face_w = int(tokens[-2])
face_h = int(tokens[-1])
return img, face_cx, face_cy, face_w, face_h
    # crop the image nicely around the face
def crop_target(self, subset: BaseSubset, image, face_cx, face_cy, face_w, face_h):
height, width = image.shape[0:2]
if height == self.height and width == self.width:
return image
        # the image is larger than the target size, so resize it
face_size = max(face_w, face_h)
        size = min(self.height, self.width)  # the shorter side
        min_scale = max(
            self.height / height, self.width / width
        )  # the smallest scale at which the image exactly covers the model input size
        min_scale = min(
            1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[1]))
        )  # scale for the specified minimum face size
        max_scale = min(
            1.0, max(min_scale, size / (face_size * subset.face_crop_aug_range[0]))
        )  # scale for the specified maximum face size
        if min_scale >= max_scale:  # the specified range has min == max
scale = min_scale
else:
scale = random.uniform(min_scale, max_scale)
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
        if nh < self.height or nw < self.width:
raise ValueError(f"internal error. small scale {scale}, {width}*{height}")
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
        # crop to e.g. 448*640 centered on the face
for axis, (target_size, length, face_p) in enumerate(
zip((self.height, self.width), (height, width), (face_cy, face_cx))
):
            p1 = face_p - target_size // 2  # crop position that centers the face
if subset.random_crop:
                # shift the crop while biasing toward keeping the face centered, so some background is included
                im_range = max(
                    length - face_p, face_p
                )  # the longer distance from an image edge to the face center
                p1 = (
                    p1
                    + (random.randint(0, im_range) + random.randint(0, im_range))
                    - im_range
                )  # a roughly triangular random offset in [-im_range, +im_range]
else:
                # only when a range is specified, add a little randomness (fairly ad hoc)
if subset.face_crop_aug_range[0] != subset.face_crop_aug_range[1]:
if face_size > size // 10 and face_size >= 40:
p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
p1 = max(0, min(p1, length - target_size))
if axis == 0:
image = image[p1 : p1 + target_size, :]
else:
image = image[:, p1 : p1 + target_size]
return image
def __len__(self):
return self._length
def __getitem__(self, index):
bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
if (
self.caching_mode is not None
): # return batch for latents/text encoder outputs caching
return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
loss_weights = []
captions = []
input_ids_list = []
input_ids2_list = []
latents_list = []
images = []
original_sizes_hw = []
crop_top_lefts = []
target_sizes_hw = []
        flippeds = []  # whether each sample is flipped (naming could be better)
text_encoder_outputs1_list = []
text_encoder_outputs2_list = []
text_encoder_pool2_list = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data.get(image_key, None)
subset = self.image_to_subset.get(image_key, None)
loss_weights.append(
self.prior_loss_weight if image_info.is_reg else 1.0
) # in case of fine tuning, is_reg is always False
flipped = (
subset.flip_aug and random.random() < 0.5
) # not flipped or flipped with 50% chance
            # process image/latents
            if image_info.latents is not None:  # when cache_latents=True
original_size = image_info.latents_original_size
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
else:
latents = image_info.latents_flipped
image = None
            elif (
                image_info.latents_npz is not None
            ):  # FineTuningDataset, or cache_latents_to_disk=True
latents, original_size, crop_ltrb, flipped_latents = (
load_latents_from_disk(image_info.latents_npz)
)
if flipped:
latents = flipped_latents
del flipped_latents
latents = torch.FloatTensor(latents)
image = None
else:
                # load the image and crop if necessary
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
subset, image_info.absolute_path
)
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop,
img,
image_info.bucket_reso,
image_info.resized_size,
)
else:
                    if face_cx > 0:  # face position info is available
img = self.crop_target(
subset, img, face_cx, face_cy, face_w, face_h
)
elif im_h > self.height or im_w > self.width:
if not subset.random_crop:
raise ValueError(
f"image too large, but cropping and bucketing are disabled : {image_info.absolute_path}"
)
if im_h > self.height:
p = random.randint(0, im_h - self.height)
img = img[p : p + self.height]
if im_w > self.width:
p = random.randint(0, im_w - self.width)
img = img[:, p : p + self.width]
im_h, im_w = img.shape[0:2]
if im_h != self.height and im_w != self.width:
raise ValueError(
f"image size is small: {image_info.absolute_path}"
)
original_size = [im_w, im_h]
crop_ltrb = (0, 0, 0, 0)
# augmentation
aug = self.aug_helper.get_augmentor(subset.color_aug)
if aug is not None:
try:
img = aug(image=img)["image"]
except KeyError:
print("Warning: Augmentation didn't return an 'image' key")
if flipped:
img = img[
:, ::-1, :
].copy() # copy to avoid negative stride problem
latents = None
                image = self.image_transforms(img)  # becomes a torch.Tensor in [-1.0, 1.0]
images.append(image)
latents_list.append(latents)
target_size = (
(image.shape[2], image.shape[1])
if image is not None
else (latents.shape[2] * 8, latents.shape[1] * 8)
)
if not flipped:
crop_left_top = (crop_ltrb[0], crop_ltrb[1])
else:
crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
flippeds.append(flipped)
            # process the caption and text encoder outputs
            caption = image_info.caption  # default
if image_info.text_encoder_outputs1 is not None:
text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
text_encoder_pool2_list.append(image_info.text_encoder_pool2)
captions.append(caption)
elif image_info.text_encoder_outputs_npz is not None:
text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = (
load_text_encoder_outputs_from_disk(
image_info.text_encoder_outputs_npz
)
)
text_encoder_outputs1_list.append(text_encoder_outputs1)
text_encoder_outputs2_list.append(text_encoder_outputs2)
text_encoder_pool2_list.append(text_encoder_pool2)
captions.append(caption)
else:
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join(
[f"{x}_{layer}" for x in self.token_strings]
)
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
if (
not self.token_padding_disabled
): # this option might be omitted in future
if self.XTI_layers:
token_caption = self.get_input_ids(
caption_layer, self.tokenizers[0]
)
else:
token_caption = self.get_input_ids(caption, self.tokenizers[0])
input_ids_list.append(token_caption)
if len(self.tokenizers) > 1:
if self.XTI_layers:
token_caption2 = self.get_input_ids(
caption_layer, self.tokenizers[1]
)
else:
token_caption2 = self.get_input_ids(
caption, self.tokenizers[1]
)
input_ids2_list.append(token_caption2)
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
if len(text_encoder_outputs1_list) == 0:
if self.token_padding_disabled:
                # padding=True pads to the longest sequence in the batch
                example["input_ids"] = self.tokenizers[0](
captions, padding=True, truncation=True, return_tensors="pt"
).input_ids
if len(self.tokenizers) > 1:
example["input_ids2"] = self.tokenizer[1](
captions, padding=True, truncation=True, return_tensors="pt"
).input_ids
else:
example["input_ids2"] = None
else:
example["input_ids"] = torch.stack(input_ids_list)
example["input_ids2"] = (
torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
)
example["text_encoder_outputs1_list"] = None
example["text_encoder_outputs2_list"] = None
example["text_encoder_pool2_list"] = None
else:
example["input_ids"] = None
example["input_ids2"] = None
example["text_encoder_outputs1_list"] = torch.stack(
text_encoder_outputs1_list
)
example["text_encoder_outputs2_list"] = torch.stack(
text_encoder_outputs2_list
)
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
if images[0] is not None:
images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float()
else:
images = None
example["images"] = images
example["latents"] = (
torch.stack(latents_list) if latents_list[0] is not None else None
)
example["captions"] = captions
example["original_sizes_hw"] = torch.stack(
[torch.LongTensor(x) for x in original_sizes_hw]
)
example["crop_top_lefts"] = torch.stack(
[torch.LongTensor(x) for x in crop_top_lefts]
)
example["target_sizes_hw"] = torch.stack(
[torch.LongTensor(x) for x in target_sizes_hw]
)
example["flippeds"] = flippeds
example["network_multipliers"] = torch.FloatTensor(
[self.network_multiplier] * len(captions)
)
example["original_sizes"] = original_sizes_hw
example["crop_top_lefts_list"] = crop_top_lefts
if self.debug_dataset:
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
captions = []
images = []
input_ids1_list = []
input_ids2_list = []
absolute_paths = []
resized_sizes = []
bucket_reso = None
flip_aug = None
random_crop = None
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data.get(image_key, None)
subset = self.image_to_subset.get(image_key, None)
if flip_aug is None:
flip_aug = subset.flip_aug
random_crop = subset.random_crop
bucket_reso = image_info.bucket_reso
else:
if flip_aug != subset.flip_aug:
raise ValueError("flip_aug must be same in a batch")
if random_crop != subset.random_crop:
raise ValueError("random_crop must be same in a batch")
if bucket_reso != image_info.bucket_reso:
raise ValueError("bucket_reso must be same in a batch")
            caption = (
                image_info.caption
            )  # TODO: caching should also cover the various patterns of dropout, shuffling, etc.
if self.caching_mode == "latents":
image = load_image(image_info.absolute_path)
else:
image = None
if self.caching_mode == "text":
input_ids1 = self.get_input_ids(caption, self.tokenizers[0])
input_ids2 = self.get_input_ids(caption, self.tokenizers[1])
else:
input_ids1 = None
input_ids2 = None
captions.append(caption)
images.append(image)
input_ids1_list.append(input_ids1)
input_ids2_list.append(input_ids2)
absolute_paths.append(image_info.absolute_path)
resized_sizes.append(image_info.resized_size)
example = {}
if images[0] is None:
images = None
example["images"] = images
example["captions"] = captions
example["input_ids1_list"] = input_ids1_list
example["input_ids2_list"] = input_ids2_list
example["absolute_paths"] = absolute_paths
example["resized_sizes"] = resized_sizes
example["flip_aug"] = flip_aug
example["random_crop"] = random_crop
example["bucket_reso"] = bucket_reso
return example
class DreamBoothDataset(BaseDataset):
def __init__(
self,
subsets: Sequence[DreamBoothSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
min_bucket_reso: int,
max_bucket_reso: int,
bucket_reso_steps: int,
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
) -> None:
super().__init__(
tokenizer, max_token_length, resolution, network_multiplier, debug_dataset
)
if resolution is None:
raise ValueError("resolution is required")
self.batch_size = batch_size
        self.size = min(self.width, self.height)  # the shorter side
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.enable_bucket = enable_bucket
if self.enable_bucket:
if min(resolution) < min_bucket_reso:
raise ValueError(
"min_bucket_reso must be equal or less than resolution"
)
if max(resolution) > max_bucket_reso:
raise ValueError(
"max_bucket_reso must be equal or greater than resolution"
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
self.min_bucket_reso = None
self.max_bucket_reso = None
            self.bucket_reso_steps = None  # this value is not used
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension):
            # build candidate caption file names
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
tokens = base_name.split("_")
if len(tokens) >= 5:
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [
base_name + caption_extension,
base_name_face_det + caption_extension,
]
caption = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
try:
lines = f.readlines()
except UnicodeDecodeError as e:
print(
f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}"
)
raise e
if len(lines) <= 0:
raise ValueError(f"caption file is empty: {cap_path}")
caption = lines[0].strip()
break
return caption
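        # read_caption example (illustrative): for "girl_256_128_64_64.png" the
        # candidates are "girl_256_128_64_64.txt" and "girl.txt", because the last
        # four "_"-separated tokens are treated as face cx/cy/w/h metadata.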
def load_dreambooth_dir(subset: DreamBoothSubset):
if not os.path.isdir(subset.image_dir):
print(f"not directory: {subset.image_dir}")
return [], []
img_paths = glob_images(subset.image_dir, "*")
print(
f"found directory {subset.image_dir} contains {len(img_paths)} image files"
)
            # read the caption for each image file and prefer it if present
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None:
print(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
            self.set_tag_frequency(
                os.path.basename(subset.image_dir), captions
            )  # record tag frequency
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = (
number_of_missing_captions - number_of_missing_captions_to_show
)
print(
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
)
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(
missing_caption
+ f"... and {remaining_missing_captions} more"
)
break
print(missing_caption)
return img_paths, captions
print("prepare images.")
num_train_images = 0
num_reg_images = 0
reg_infos: List[ImageInfo] = []
for subset in subsets:
if subset.num_repeats < 1:
print(
f"ignore subset with image_dir='{subset.image_dir}': num_repeats is less than 1 / num_repeatsが1を下回っているためサブセットを無視します: {subset.num_repeats}"
)
continue
if subset in self.subsets:
print(
f"ignore duplicated subset with image_dir='{subset.image_dir}': use the first one / 既にサブセットが登録されているため、重複した後発のサブセットを無視します"
)
continue
img_paths, captions = load_dreambooth_dir(subset)
if len(img_paths) < 1:
print(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
)
continue
if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
else:
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption in zip(img_paths, captions):
info = ImageInfo(
img_path, subset.num_repeats, caption, subset.is_reg, img_path
)
if subset.is_reg:
reg_infos.append(info)
else:
self.register_image(info, subset)
subset.img_count = len(img_paths)
self.subsets.append(subset)
print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images
print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images:
print(
"some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります"
)
if num_reg_images == 0:
print("no regularization images / 正則化画像が見つかりませんでした")
else:
            # compute num_repeats: the counts are small anyway, so a simple loop is fine
n = 0
first_loop = True
while n < num_train_images:
for info in reg_infos:
if first_loop:
self.register_image(info, subset)
n += info.num_repeats
else:
info.num_repeats += 1 # rewrite registered info
n += 1
if n >= num_train_images:
break
first_loop = False
self.num_reg_images = num_reg_images
# behave as Dataset mock
class DatasetGroup(torch.utils.data.ConcatDataset):
def __init__(self, datasets: Sequence[DreamBoothDataset]):
self.datasets: List[DreamBoothDataset]
super().__init__(datasets)
self.image_data = {}
self.num_train_images = 0
self.num_reg_images = 0
        # simply concatenate the datasets
        # duplicated image_data keys across datasets would need handling, but in
        # practice this is not a big issue because image_data is accessed from
        # outside the dataset only for debug_dataset
for dataset in datasets:
self.image_data.update(dataset.image_data)
self.num_train_images += dataset.num_train_images
self.num_reg_images += dataset.num_reg_images
def add_replacement(self, str_from, str_to):
for dataset in self.datasets:
dataset.add_replacement(str_from, str_to)
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
def cache_latents(
self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True
):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def cache_text_encoder_outputs(
self,
tokenizers,
text_encoders,
device,
weight_dtype,
cache_to_disk=False,
is_main_process=True,
):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
dataset.cache_text_encoder_outputs(
tokenizers,
text_encoders,
device,
weight_dtype,
cache_to_disk,
is_main_process,
)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)
def verify_bucket_reso_steps(self, min_steps: int):
for dataset in self.datasets:
dataset.verify_bucket_reso_steps(min_steps)
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
def is_text_encoder_output_cacheable(self) -> bool:
return all(
[dataset.is_text_encoder_output_cacheable() for dataset in self.datasets]
)
def set_current_epoch(self, epoch):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
# collate_fn: epoch and step are multiprocessing.Value instances shared with the training loop
class collator_class:
def __init__(self, epoch, step, dataset):
self.current_epoch = epoch
self.current_step = step
        self.dataset = (
            dataset  # used only when worker_info is None, i.e. without multiprocessing workers
        )
def __call__(self, examples):
worker_info = torch.utils.data.get_worker_info()
# worker_info is None in the main process
if worker_info is not None:
dataset = worker_info.dataset
else:
dataset = self.dataset
# set epoch and step
dataset.set_current_epoch(self.current_epoch.value)
dataset.set_current_step(self.current_step.value)
return examples[0]
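# A minimal usage sketch (assumed wiring, not part of this file): the collator is
# passed to a DataLoader with batch_size=1, because the dataset itself returns
# whole batches; the shared multiprocessing.Value objects let the training loop
# propagate epoch/step into worker processes for caption processing.
#
#   import multiprocessing
#   current_epoch = multiprocessing.Value("i", 0)
#   current_step = multiprocessing.Value("i", 0)
#   collator = collator_class(current_epoch, current_step, dataset)
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=1, num_workers=4, collate_fn=collator
#   )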
def add_config_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--dataset_config",
type=Path,
default=None,
help="config file for detail settings / 詳細な設定用の設定ファイル",
)
# TODO: the Subset/Dataset classes should inherit from these Params classes
@dataclass
class BaseSubsetParams:
image_dir: Optional[str] = None
num_repeats: int = 1
shuffle_caption: bool = False
caption_separator: tuple = (",",)
keep_tokens: int = 0
    keep_tokens_separator: Optional[str] = None
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
random_crop: bool = False
caption_prefix: Optional[str] = None
caption_suffix: Optional[str] = None
caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: int = 0
@dataclass
class DreamBoothSubsetParams(BaseSubsetParams):
is_reg: bool = False
class_tokens: Optional[str] = None
caption_extension: str = ".txt"
@dataclass
class FineTuningSubsetParams(BaseSubsetParams):
metadata_file: Optional[str] = None
@dataclass
class ControlNetSubsetParams(BaseSubsetParams):
    conditioning_data_dir: Optional[str] = None
caption_extension: str = ".caption"
@dataclass
class BaseDatasetParams:
    tokenizer: Optional[Union[CLIPTokenizer, List[CLIPTokenizer]]] = None
    max_token_length: Optional[int] = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
batch_size: int = 1
enable_bucket: bool = False
min_bucket_reso: int = 256
max_bucket_reso: int = 1024
bucket_reso_steps: int = 64
bucket_no_upscale: bool = False
prior_loss_weight: float = 1.0
@dataclass
class SubsetBlueprint:
params: DreamBoothSubsetParams
@dataclass
class DatasetBlueprint:
is_dreambooth: bool
is_controlnet: bool
params: DreamBoothDatasetParams
subsets: Sequence[SubsetBlueprint]
@dataclass
class DatasetGroupBlueprint:
datasets: Sequence[DatasetBlueprint]
@dataclass
class Blueprint:
dataset_group: DatasetGroupBlueprint
class ConfigSanitizer:
@staticmethod
def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
Schema(ExactSequence([klass, klass]))(value)
return tuple(value)
@staticmethod
def __validate_and_convert_scalar_or_twodim(
klass, value: Union[float, Sequence]
) -> Tuple:
Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except voluptuous.Invalid:
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)
# subset schema
SUBSET_ASCENDABLE_SCHEMA = {
"color_aug": bool,
"face_crop_aug_range": functools.partial(
__validate_and_convert_twodim.__func__, float
),
"flip_aug": bool,
"num_repeats": int,
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"token_warmup_min": int,
"token_warmup_step": Any(float, int),
"caption_prefix": str,
"caption_suffix": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
"caption_dropout_every_n_epochs": int,
"caption_dropout_rate": Any(float, int),
"caption_tag_dropout_rate": Any(float, int),
}
# DB means DreamBooth
DB_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
"class_tokens": str,
}
DB_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
"is_reg": bool,
}
# FT means FineTuning
FT_SUBSET_DISTINCT_SCHEMA = {
Required("metadata_file"): str,
"image_dir": str,
}
CN_SUBSET_ASCENDABLE_SCHEMA = {
"caption_extension": str,
}
CN_SUBSET_DISTINCT_SCHEMA = {
Required("image_dir"): str,
Required("conditioning_data_dir"): str,
}
# datasets schema
DATASET_ASCENDABLE_SCHEMA = {
"batch_size": int,
"bucket_no_upscale": bool,
"bucket_reso_steps": int,
"enable_bucket": bool,
"max_bucket_reso": int,
"min_bucket_reso": int,
"resolution": functools.partial(
__validate_and_convert_scalar_or_twodim.__func__, int
),
"network_multiplier": float,
}
# options handled by argparse but not handled by user config
ARGPARSE_SPECIFIC_SCHEMA = {
"debug_dataset": bool,
"max_token_length": Any(None, int),
"prior_loss_weight": Any(float, int),
}
# for handling default None value of argparse
ARGPARSE_NULLABLE_OPTNAMES = [
"face_crop_aug_range",
"resolution",
]
# prepare map because option name may differ among argparse and user config
ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME = {
"train_batch_size": "batch_size",
"dataset_repeats": "num_repeats",
}
def __init__(
self,
support_dreambooth: bool,
support_finetuning: bool,
support_controlnet: bool,
support_dropout: bool,
) -> None:
if not (support_dreambooth or support_finetuning or support_controlnet):
raise ValueError(
"Neither DreamBooth mode nor fine tuning mode specified. Please specify one mode or more."
)
self.db_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
self.DB_SUBSET_DISTINCT_SCHEMA,
self.DB_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.ft_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
self.FT_SUBSET_DISTINCT_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.cn_subset_schema = self.__merge_dict(
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_DISTINCT_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.db_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.DB_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
{"subsets": [self.db_subset_schema]},
)
self.ft_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
{"subsets": [self.ft_subset_schema]},
)
self.cn_dataset_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.CN_SUBSET_ASCENDABLE_SCHEMA,
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
{"subsets": [self.cn_subset_schema]},
)
if support_dreambooth and support_finetuning:
def validate_flex_dataset(dataset_config: dict):
subsets_config = dataset_config.get("subsets", [])
if support_controlnet and all(
["conditioning_data_dir" in subset for subset in subsets_config]
):
return Schema(self.cn_dataset_schema)(dataset_config)
# check dataset meets FT style
# NOTE: all FT subsets should have "metadata_file"
elif all(["metadata_file" in subset for subset in subsets_config]):
return Schema(self.ft_dataset_schema)(dataset_config)
# check dataset meets DB style
# NOTE: all DB subsets should have no "metadata_file"
elif all(["metadata_file" not in subset for subset in subsets_config]):
return Schema(self.db_dataset_schema)(dataset_config)
else:
                    raise voluptuous.Invalid(
                        "DreamBooth subset and fine tuning subset cannot be mixed in the same dataset. Please split them into separate datasets. / DreamBoothのサブセットとfine tuningのサブセットを同一のデータセットに混在させることはできません。別々のデータセットに分割してください。"
                    )
self.dataset_schema = validate_flex_dataset
elif support_dreambooth:
self.dataset_schema = self.db_dataset_schema
elif support_finetuning:
self.dataset_schema = self.ft_dataset_schema
elif support_controlnet:
self.dataset_schema = self.cn_dataset_schema
self.general_schema = self.__merge_dict(
self.DATASET_ASCENDABLE_SCHEMA,
self.SUBSET_ASCENDABLE_SCHEMA,
self.DB_SUBSET_ASCENDABLE_SCHEMA if support_dreambooth else {},
self.CN_SUBSET_ASCENDABLE_SCHEMA if support_controlnet else {},
self.DO_SUBSET_ASCENDABLE_SCHEMA if support_dropout else {},
)
self.user_config_validator = Schema(
{
"general": self.general_schema,
"datasets": [self.dataset_schema],
}
)
self.argparse_schema = self.__merge_dict(
self.general_schema,
self.ARGPARSE_SPECIFIC_SCHEMA,
{
optname: Any(None, self.general_schema.get(optname, None))
for optname in self.ARGPARSE_NULLABLE_OPTNAMES
},
{
a_name: self.general_schema.get(c_name, None)
for a_name, c_name in self.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME.items()
},
)
self.argparse_config_validator = Schema(
Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA
)
def sanitize_user_config(self, user_config: dict) -> dict:
try:
return self.user_config_validator(user_config)
except MultipleInvalid:
print("Invalid user config / ユーザ設定の形式が正しくないようです")
raise
    # NOTE: by nature the argparse result does not need to be sanitized,
    # but doing so helps us detect program bugs
def sanitize_argparse_namespace(
self, argparse_namespace: argparse.Namespace
) -> argparse.Namespace:
try:
return self.argparse_config_validator(argparse_namespace)
except MultipleInvalid:
# Found a bug
print(
"Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
)
raise
    # NOTE: values are overwritten by later dicts when the same key already exists
@staticmethod
def __merge_dict(*dict_list: dict) -> dict:
merged = {}
for schema in dict_list:
# merged |= schema
for k, v in schema.items():
merged[k] = v
return merged
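    # Example: __merge_dict({"a": 1}, {"a": 2, "b": 3}) -> {"a": 2, "b": 3}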
class BlueprintGenerator:
BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}
def __init__(self, sanitizer: ConfigSanitizer):
self.sanitizer = sanitizer
    # runtime_params is for parameters that are only configurable at runtime, such as tokenizer
def generate(
self,
user_config: dict,
argparse_namespace: argparse.Namespace,
**runtime_params,
) -> Blueprint:
sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(
argparse_namespace
)
# convert argparse namespace to dict like config
# NOTE: it is ok to have extra entries in dict
optname_map = self.sanitizer.ARGPARSE_OPTNAME_TO_CONFIG_OPTNAME
argparse_config = {
optname_map.get(optname, optname): value
for optname, value in vars(sanitized_argparse_namespace).items()
}
general_config = sanitized_user_config.get("general", {})
dataset_blueprints = []
for dataset_config in sanitized_user_config.get("datasets", []):
# NOTE: if subsets have no "metadata_file", these are DreamBooth datasets/subsets
subsets = dataset_config.get("subsets", [])
is_dreambooth = all(["metadata_file" not in subset for subset in subsets])
# NOTE: ControlNet dataset detection is hardcoded off in this build
is_controlnet = False
if is_controlnet:
raise NotImplementedError("ControlNet datasets are not supported yet")
elif is_dreambooth:
subset_params_klass = DreamBoothSubsetParams
dataset_params_klass = DreamBoothDatasetParams
else:
raise NotImplementedError("Not support now")
subset_blueprints = []
for subset_config in subsets:
params = self.generate_params_by_fallbacks(
subset_params_klass,
[
subset_config,
dataset_config,
general_config,
argparse_config,
runtime_params,
],
)
subset_blueprints.append(SubsetBlueprint(params))
params = self.generate_params_by_fallbacks(
dataset_params_klass,
[dataset_config, general_config, argparse_config, runtime_params],
)
dataset_blueprints.append(
DatasetBlueprint(
is_dreambooth, is_controlnet, params, subset_blueprints
)
)
dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)
return Blueprint(dataset_group_blueprint)
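# Example (sketch): end-to-end blueprint generation. `args` is assumed to be
# the namespace produced by this script's argparse parser, and `tokenizer` is
# forwarded as a runtime-only parameter (see the note on runtime_params above).
#
#   generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
#   blueprint = generator.generate(user_config, args, tokenizer=tokenizer)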
@staticmethod
def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
search_value = BlueprintGenerator.search_value
default_params = asdict(param_klass())
param_names = default_params.keys()
params = {
name: search_value(
name_map.get(name, name), fallbacks, default_params.get(name)
)
for name in param_names
}
return param_klass(**params)
@staticmethod
def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
for cand in fallbacks:
value = cand.get(key)
if value is not None:
return value
return default_value
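# Example: the first dict in `fallbacks` with a non-None value wins.
#   search_value("num_repeats", [{"num_repeats": 3}, {"num_repeats": 10}])  ->  3
#   search_value("num_repeats", [{}, {}], default_value=1)  ->  1
# Note that falsy values such as 0 or False are still returned; only None
# (or a missing key) falls through to the next candidate.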
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = max_width * max_height
resos = set()
width = int(math.sqrt(max_area) // divisible) * divisible
resos.add((width, width))
width = min_size
while width <= max_size:
height = min(max_size, int((max_area // width) // divisible) * divisible)
if height >= min_size:
resos.add((width, height))
resos.add((height, width))
width += divisible
resos_list = list(resos)
resos_list.sort()
return resos_list
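# Example: each bucket keeps width * height <= max_width * max_height and both
# sides divisible by `divisible`; with max_size equal to the max side, the long
# side is simply clamped:
#   make_bucket_resolutions((512, 512), min_size=256, max_size=512, divisible=64)
#   -> [(256, 512), (320, 512), (384, 512), (448, 512),
#       (512, 256), (512, 320), (512, 384), (512, 448), (512, 512)]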
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
datasets: List[DreamBoothDataset] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.is_controlnet:
raise NotImplementedError("Not support now")
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
raise NotImplementedError("Not support now")
subsets = [
subset_klass(**asdict(subset_blueprint.params))
for subset_blueprint in dataset_blueprint.subsets
]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
datasets.append(dataset)
# print info
info = ""
for i, dataset in enumerate(datasets):
is_dreambooth = isinstance(dataset, DreamBoothDataset)
is_controlnet = False
info += dedent(
f"""\
[Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)
if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"
for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epochs: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
caption_prefix: {subset.caption_prefix}
caption_suffix: {subset.caption_suffix}
color_aug: {subset.color_aug}
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min}
token_warmup_step: {subset.token_warmup_step}
"""
),
" ",
)
if is_dreambooth:
info += indent(
dedent(
f"""\
is_reg: {subset.is_reg}
class_tokens: {subset.class_tokens}
caption_extension: {subset.caption_extension}
\n"""
),
" ",
)
elif not is_controlnet:
info += indent(
dedent(
f"""\
metadata_file: {subset.metadata_file}
\n"""
),
" ",
)
print(info)
# make buckets first because they determine the length of each dataset,
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
for i, dataset in enumerate(datasets):
print(f"[Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)
return DatasetGroup(datasets)
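# Example (sketch): turning a blueprint into concrete datasets, assuming the
# Blueprint class defined earlier in this file exposes the group as
# `.dataset_group`:
#   blueprint = generator.generate(user_config, args, tokenizer=tokenizer)
#   dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)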
def glob_images(directory, base="*"):
img_paths = []
for ext in IMAGE_EXTENSIONS:
if base == "*":
img_paths.extend(
glob.glob(os.path.join(glob.escape(directory), base + ext))
)
else:
img_paths.extend(
glob.glob(glob.escape(os.path.join(directory, base + ext)))
)
img_paths = list(set(img_paths))
img_paths.sort()
return img_paths
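# Example: collect every supported image directly under a directory, or only
# files with a given stem; all registered extensions are tried either way.
#   glob_images("train/5_sks dog")           # -> sorted paths of all images
#   glob_images("train/5_sks dog", "0001")   # -> only files named 0001.<ext>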
def load_tokenizer(args: argparse.Namespace):
print("prepare tokenizer")
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
tokenizer: Optional[CLIPTokenizer] = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(
args.tokenizer_cache_dir, original_path.replace("/", "_")
)
if os.path.exists(local_tokenizer_path):
print(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(
local_tokenizer_path, local_files_only=True
) # same for v1 and v2
if tokenizer is None:
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(
original_path, subfolder="tokenizer", local_files_only=True
)
else:
tokenizer = CLIPTokenizer.from_pretrained(original_path, local_files_only=True)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
print(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
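# Example (sketch): loading the tokenizer with a local cache directory. Note
# that this file passes local_files_only=True, so the tokenizer files must
# already be available locally on the first run.
#   args = argparse.Namespace(
#       v2=False, tokenizer_cache_dir="./tokenizer_cache", max_token_length=225
#   )
#   tokenizer = load_tokenizer(args)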
def generate_dreambooth_subsets_config_by_subdirs(
train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None
):
def extract_dreambooth_params(name: str) -> Tuple[int, str]:
tokens = name.split("_")
try:
n_repeats = int(tokens[0])
except ValueError:
print(f"ignoring directory whose name does not start with a repeat count: {name}")
return 0, ""
caption_by_folder = "_".join(tokens[1:])
return n_repeats, caption_by_folder
def generate(base_dir: Optional[str], is_reg: bool):
if base_dir is None:
return []
base_dir: Path = Path(base_dir)
if not base_dir.is_dir():
return []
subsets_config = []
for subdir in base_dir.iterdir():
if not subdir.is_dir():
continue
num_repeats, class_tokens = extract_dreambooth_params(subdir.name)
if num_repeats < 1:
continue
subset_config = {
"image_dir": str(subdir),
"num_repeats": num_repeats,
"is_reg": is_reg,
"class_tokens": class_tokens,
}
subsets_config.append(subset_config)
if not subsets_config:
subset_config = {
"image_dir": str(base_dir),
"num_repeats": 1,
"is_reg": is_reg,
"class_tokens": str(base_dir),
}
subsets_config.append(subset_config)
return subsets_config
subsets_config = []
subsets_config += generate(train_data_dir, False)
subsets_config += generate(reg_data_dir, True)
return subsets_config
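# Example: the conventional DreamBooth layout encodes the repeat count and the
# class tokens in each subdirectory name as "<num_repeats>_<class tokens>":
#   train_data_dir/
#     5_sks dog/  -> {"image_dir": ".../5_sks dog", "num_repeats": 5,
#                     "is_reg": False, "class_tokens": "sks dog"}
#   reg_data_dir/
#     1_dog/      -> {"image_dir": ".../1_dog", "num_repeats": 1,
#                     "is_reg": True, "class_tokens": "dog"}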