KohyaSS/train_network.py

1453 lines
65 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gc
import importlib
import json
import time
from typing import NamedTuple
from torch.autograd.function import Function
import argparse
import glob
import math
import os
import random
from tqdm import tqdm
import torch
from torchvision import transforms
from accelerate import Accelerator
from accelerate.utils import set_seed
from transformers import CLIPTokenizer
import diffusers
from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
from PIL import Image
import cv2
from einops import rearrange
from torch import einsum
import library.model_util as model_util
# Tokenizer: use the pre-published tokenizer rather than loading one from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is taken from here; v2 and v2.1 share the same tokenizer spec
# checkpoint file/directory name templates ("{:06d}" is filled with a number via str.format)
EPOCH_STATE_NAME = "epoch-{:06d}-state"
LAST_STATE_NAME = "last-state"
EPOCH_FILE_NAME = "epoch-{:06d}"
LAST_FILE_NAME = "last"
# region dataset
class ImageInfo():
    """Per-image record kept by the datasets.

    The constructor stores the five identifying values; every derived field
    (image size, bucket resolution, cached latents, npz file paths) starts as
    ``None`` and is filled in later by the dataset that owns the record.
    """

    def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
        # Values supplied by the caller.
        self.image_key = image_key
        self.num_repeats = num_repeats
        self.caption = caption
        self.is_reg = is_reg
        self.absolute_path = absolute_path

        # (width, height) pairs, unknown until computed.
        self.image_size = None
        self.bucket_reso = None

        # In-memory latents (normal / horizontally flipped), unknown until cached.
        self.latents = None
        self.latents_flipped = None

        # Paths of pre-computed latents on disk (normal / flipped), if any.
        self.latents_npz = None
        self.latents_npz_flipped = None
class BucketBatchIndex(NamedTuple):
    """Address of one batch: which bucket it lives in and its position inside that bucket."""

    # Index into the dataset's list of buckets.
    bucket_index: int
    # Zero-based batch number within that bucket.
    batch_index: int
class BaseDataset(torch.utils.data.Dataset):
    """Bucketed image/caption dataset shared by the DreamBooth and fine-tuning datasets.

    One item of this dataset is a whole batch (see ``__getitem__``), so it is
    meant to be consumed with a pass-through collate function.

    Subclasses must set ``batch_size`` and register their images with
    ``register_image`` before ``make_buckets`` is called.  The pixel-loading
    path additionally reads ``prior_loss_weight``, ``random_crop`` and ``size``
    from the instance; those are not set here.
    """

    def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None:
        """Store the options and build the augmentation / tensor-conversion pipelines.

        ``resolution`` is a ``(width, height)`` pair.  When ``max_token_length``
        is given, the effective tokenizer length is ``max_token_length + 2``.
        """
        super().__init__()
        self.tokenizer: CLIPTokenizer = tokenizer
        self.max_token_length = max_token_length
        self.shuffle_caption = shuffle_caption
        self.shuffle_keep_tokens = shuffle_keep_tokens
        self.width, self.height = resolution
        self.face_crop_aug_range = face_crop_aug_range
        self.flip_aug = flip_aug
        self.color_aug = color_aug
        self.debug_dataset = debug_dataset

        self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2

        # augmentation
        flip_p = 0.5 if flip_aug else 0.0
        if color_aug:
            # Deliberately weak color augmentation: brightness/contrast would change the
            # min/max pixel values of the image, which is assumed to be undesirable,
            # so only hue and gamma are touched.
            self.aug = albu.Compose([
                albu.OneOf([
                    albu.HueSaturationValue(8, 0, 0, p=.5),
                    albu.RandomGamma((95, 105), p=.5),
                ], p=.33),
                albu.HorizontalFlip(p=flip_p)
            ], p=1.)
        elif flip_aug:
            self.aug = albu.Compose([
                albu.HorizontalFlip(p=flip_p)
            ], p=1.)
        else:
            self.aug = None

        self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ])

        # image_key -> ImageInfo
        self.image_data: "dict[str, ImageInfo]" = {}

    def process_caption(self, caption):
        """Return ``caption``, with its comma-separated tokens shuffled when enabled.

        When ``shuffle_keep_tokens`` is set, that many leading tokens keep their
        position and only the remainder is shuffled; if the caption has no more
        tokens than that, nothing is shuffled.
        """
        if self.shuffle_caption:
            tokens = caption.strip().split(",")
            if self.shuffle_keep_tokens is None:
                random.shuffle(tokens)
            elif len(tokens) > self.shuffle_keep_tokens:
                keep_tokens = tokens[:self.shuffle_keep_tokens]
                tokens = tokens[self.shuffle_keep_tokens:]
                random.shuffle(tokens)
                tokens = keep_tokens + tokens
            caption = ",".join(tokens).strip()
        return caption

    def get_input_ids(self, caption):
        """Tokenize ``caption``.

        With the default token length the tokenizer output is returned as is.
        With an extended length the single long sequence is re-cut into
        several ``model_max_length`` sized chunks, each wrapped with the first
        and last id of the long sequence, and the chunks are stacked.
        """
        input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
                                   max_length=self.tokenizer_max_length, return_tensors="pt").input_ids

        model_max_length = self.tokenizer.model_max_length
        if self.tokenizer_max_length > model_max_length:
            input_ids = input_ids.squeeze(0)
            eos_id = self.tokenizer.eos_token_id
            pad_id = self.tokenizer.pad_token_id
            # v1 tokenizers pad with EOS ("<BOS> ... <EOS> <EOS> ..."), so the chunks need no
            # fix-up; v2 tokenizers pad with a separate PAD id ("<BOS> ... <EOS> <PAD> ...").
            pads_with_eos = pad_id == eos_id
            body_length = model_max_length - 2

            iids_list = []
            # e.g. range(1, 152, 75) for a total length of 227 and a model length of 77
            for i in range(1, self.tokenizer_max_length - model_max_length + 2, body_length):
                ids_chunk = torch.cat((input_ids[0].unsqueeze(0),       # BOS
                                       input_ids[i:i + body_length],
                                       input_ids[-1].unsqueeze(0)))     # PAD or EOS
                if not pads_with_eos:
                    # A chunk ending in "<EOS> <PAD>" or "<PAD> <PAD>" is already fine.
                    # A chunk ending in "x <PAD/EOS>" gets its last id forced to EOS.
                    if ids_chunk[-2] != eos_id and ids_chunk[-2] != pad_id:
                        ids_chunk[-1] = eos_id
                    # A chunk starting "<BOS> <PAD> ..." becomes "<BOS> <EOS> <PAD> ...".
                    if ids_chunk[1] == pad_id:
                        ids_chunk[1] = eos_id
                iids_list.append(ids_chunk)

            input_ids = torch.stack(iids_list)      # e.g. 3,77
        return input_ids

    def register_image(self, info: "ImageInfo"):
        """Add (or overwrite) ``info`` in ``image_data`` under its ``image_key``."""
        self.image_data[info.image_key] = info

    def make_buckets(self, enable_bucket, min_size, max_size):
        '''
        Assign every registered image to a bucket and build the batch index.

        Must be called even when bucketing is disabled; in that case a single
        bucket at the dataset resolution is created.
        min_size and max_size are ignored when enable_bucket is False
        '''
        self.enable_bucket = enable_bucket

        print("loading image sizes.")
        for info in tqdm(self.image_data.values()):
            if info.image_size is None:
                info.image_size = self.get_image_size(info.absolute_path)

        if enable_bucket:
            print("make buckets")
        else:
            print("prepare dataset")

        # prepare the bucket resolutions
        if enable_bucket:
            bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size)
        else:
            # a single bucket; every image has the same resolution
            bucket_resos = [(self.width, self.height)]
            bucket_aspect_ratios = [self.width / self.height]
        bucket_aspect_ratios = np.array(bucket_aspect_ratios)

        # choose a bucket for every image
        if enable_bucket:
            img_ar_errors = []
            for image_info in self.image_data.values():
                # pick the bucket whose aspect ratio is closest to the image's
                image_width, image_height = image_info.image_size
                aspect_ratio = image_width / image_height
                ar_errors = bucket_aspect_ratios - aspect_ratio

                bucket_id = np.abs(ar_errors).argmin()
                image_info.bucket_reso = bucket_resos[bucket_id]

                ar_error = ar_errors[bucket_id]
                img_ar_errors.append(ar_error)
        else:
            reso = (self.width, self.height)
            for image_info in self.image_data.values():
                image_info.bucket_reso = reso

        # distribute the image keys over the buckets, once per repeat
        self.buckets: list[list[str]] = [[] for _ in range(len(bucket_resos))]
        reso_to_index = {}
        for i, reso in enumerate(bucket_resos):
            reso_to_index[reso] = i

        for image_info in self.image_data.values():
            bucket_index = reso_to_index[image_info.bucket_reso]
            for _ in range(image_info.num_repeats):
                self.buckets[bucket_index].append(image_info.image_key)

        if enable_bucket:
            print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数DreamBoothの場合は繰り返し回数を含む")
            for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
                print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
            img_ar_errors = np.array(img_ar_errors)
            print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")

        # build the lookup index: one entry per batch
        self.buckets_indices: list[BucketBatchIndex] = []
        for bucket_index, bucket in enumerate(self.buckets):
            batch_count = int(math.ceil(len(bucket) / self.batch_size))
            for batch_index in range(batch_count):
                self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index))

        self.shuffle_buckets()
        self._length = len(self.buckets_indices)

    def shuffle_buckets(self):
        """Shuffle the batch order and, independently, the image order inside each bucket."""
        random.shuffle(self.buckets_indices)
        for bucket in self.buckets:
            random.shuffle(bucket)

    def load_image(self, image_path):
        """Load ``image_path`` as an RGB ``uint8`` numpy array."""
        # The context manager closes the underlying file once the pixels are copied out.
        with Image.open(image_path) as image:
            if not image.mode == "RGB":
                image = image.convert("RGB")
            img = np.array(image, np.uint8)
        return img

    def resize_and_trim(self, image, reso):
        """Scale ``image`` to cover ``reso`` (width, height), then center-trim the overflow."""
        image_height, image_width = image.shape[0:2]
        ar_img = image_width / image_height
        ar_reso = reso[0] / reso[1]
        if ar_img > ar_reso:                   # image is wider than the target -> match the height
            scale = reso[1] / image_height
        else:
            scale = reso[0] / image_width
        resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))

        # resize with cv2 because INTER_AREA is wanted
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        if resized_size[0] > reso[0]:
            trim_size = resized_size[0] - reso[0]
            image = image[:, trim_size//2:trim_size//2 + reso[0]]
        elif resized_size[1] > reso[1]:
            trim_size = resized_size[1] - reso[1]
            image = image[trim_size//2:trim_size//2 + reso[1]]
        assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
            f"internal error, illegal trimmed size: {image.shape}, {reso}"
        return image

    def cache_latents(self, vae):
        """Fill ``latents`` / ``latents_flipped`` of every registered image.

        Images that have an npz file are loaded from disk; a missing flipped
        npz leaves ``latents_flipped`` as ``None``.  All other images are
        encoded with ``vae`` (the flipped variant only when ``flip_aug`` is on).
        """
        print("caching latents.")
        for info in tqdm(self.image_data.values()):
            if info.latents_npz is not None:
                info.latents = torch.FloatTensor(self.load_latents_from_npz(info, False))
                # The flipped file is optional: it only has to exist when flip_aug is used.
                flipped = self.load_latents_from_npz(info, True)
                info.latents_flipped = None if flipped is None else torch.FloatTensor(flipped)
                continue

            image = self.load_image(info.absolute_path)
            image = self.resize_and_trim(image, info.bucket_reso)

            img_tensor = self.image_transforms(image)
            img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
            info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")

            if self.flip_aug:
                image = image[:, ::-1].copy()     # cannot convert to Tensor without copy
                img_tensor = self.image_transforms(image)
                img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
                info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")

    def get_image_size(self, image_path):
        """Return the ``(width, height)`` of the image file without keeping it open."""
        with Image.open(image_path) as image:
            return image.size

    def load_image_with_face_info(self, image_path: str):
        """Load the image and parse face geometry from its file name.

        When ``face_crop_aug_range`` is set and the base name has at least five
        ``_``-separated parts, the last four are read as integers
        ``face_cx, face_cy, face_w, face_h``; otherwise all four are 0.
        """
        img = self.load_image(image_path)

        face_cx = face_cy = face_w = face_h = 0
        if self.face_crop_aug_range is not None:
            tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
            if len(tokens) >= 5:
                face_cx = int(tokens[-4])
                face_cy = int(tokens[-3])
                face_w = int(tokens[-2])
                face_h = int(tokens[-1])

        return img, face_cx, face_cy, face_w, face_h

    # cut out a nicely framed region
    def crop_target(self, image, face_cx, face_cy, face_w, face_h):
        """Rescale and crop ``image`` to the dataset resolution around the face.

        Reads ``self.size`` and ``self.random_crop``, which the subclass must provide.
        """
        height, width = image.shape[0:2]
        if height == self.height and width == self.width:
            return image

        # the image is larger than the target size, so resize it
        face_size = max(face_w, face_h)
        min_scale = max(self.height / height, self.width / width)        # scale at which the image exactly covers the model input (smallest allowed)
        min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1])))             # scale derived from face_crop_aug_range[1]
        max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0])))             # scale derived from face_crop_aug_range[0]
        if min_scale >= max_scale:          # the range was given as min == max
            scale = min_scale
        else:
            scale = random.uniform(min_scale, max_scale)

        nh = int(height * scale + .5)
        nw = int(width * scale + .5)
        assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
        image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
        face_cx = int(face_cx * scale + .5)
        face_cy = int(face_cy * scale + .5)
        height, width = nh, nw

        # crop to e.g. 448*640 centered on the face, one axis at a time
        for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
            p1 = face_p - target_size // 2                # crop origin that puts the face in the center

            if self.random_crop:
                # shift the crop so background is included too, while keeping the face likely near the center
                max_shift = max(length - face_p, face_p)        # the longer distance from the face center to an image edge
                p1 = p1 + (random.randint(0, max_shift) + random.randint(0, max_shift)) - max_shift     # random offset in -max_shift..+max_shift, peaked at 0
            else:
                # only when a real range was given: jitter slightly (rather ad hoc)
                if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
                    if face_size > self.size // 10 and face_size >= 40:
                        p1 = p1 + random.randint(-face_size // 20, +face_size // 20)

            p1 = max(0, min(p1, length - target_size))

            if axis == 0:
                image = image[p1:p1 + target_size, :]
            else:
                image = image[:, p1:p1 + target_size]

        return image

    def load_latents_from_npz(self, image_info: "ImageInfo", flipped):
        """Return the ``arr_0`` array of the (flipped) npz file, or ``None`` if no such file is recorded."""
        npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
        if npz_file is None:
            return None
        # Close the archive after the array has been read.
        with np.load(npz_file) as data:
            return data['arr_0']

    def __len__(self):
        """Number of batches, as computed by ``make_buckets``."""
        return self._length

    def __getitem__(self, index):
        """Build the batch addressed by ``buckets_indices[index]``.

        Returns a dict with ``loss_weights``, ``input_ids``, ``images`` (or
        ``None``) and ``latents`` (or ``None``); in debug mode also
        ``image_keys`` and ``captions``.  Requesting index 0 reshuffles the
        buckets first.
        """
        if index == 0:
            self.shuffle_buckets()

        bucket = self.buckets[self.buckets_indices[index].bucket_index]
        image_index = self.buckets_indices[index].batch_index * self.batch_size

        loss_weights = []
        captions = []
        input_ids_list = []
        latents_list = []
        images = []

        for image_key in bucket[image_index:image_index + self.batch_size]:
            image_info = self.image_data[image_key]
            loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

            # process the image / latents
            if image_info.latents is not None:
                use_flipped = self.flip_aug and random.random() >= .5
                # Fall back to the unflipped latents if no flipped ones were cached.
                if use_flipped and image_info.latents_flipped is not None:
                    latents = image_info.latents_flipped
                else:
                    latents = image_info.latents
                image = None
            elif image_info.latents_npz is not None:
                use_flipped = self.flip_aug and random.random() >= .5
                latents = self.load_latents_from_npz(image_info, use_flipped)
                if latents is None:
                    # no flipped file on disk: use the unflipped latents instead
                    latents = self.load_latents_from_npz(image_info, False)
                latents = torch.FloatTensor(latents)
                image = None
            else:
                # load the image and crop it if necessary
                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
                im_h, im_w = img.shape[0:2]

                if self.enable_bucket:
                    img = self.resize_and_trim(img, image_info.bucket_reso)
                else:
                    if face_cx > 0:                                   # face position is available
                        img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
                    elif im_h > self.height or im_w > self.width:
                        assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
                        if im_h > self.height:
                            p = random.randint(0, im_h - self.height)
                            img = img[p:p + self.height]
                        if im_w > self.width:
                            p = random.randint(0, im_w - self.width)
                            img = img[:, p:p + self.width]

                    im_h, im_w = img.shape[0:2]
                    assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"

                # augmentation
                if self.aug is not None:
                    img = self.aug(image=img)['image']

                latents = None
                image = self.image_transforms(img)      # becomes a torch.Tensor in -1.0~1.0

            images.append(image)
            latents_list.append(latents)

            caption = self.process_caption(image_info.caption)
            captions.append(caption)
            input_ids_list.append(self.get_input_ids(caption))

        example = {}
        example['loss_weights'] = torch.FloatTensor(loss_weights)
        example['input_ids'] = torch.stack(input_ids_list)

        if images[0] is not None:
            images = torch.stack(images)
            images = images.to(memory_format=torch.contiguous_format).float()
        else:
            images = None
        example['images'] = images

        example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None

        if self.debug_dataset:
            example['image_keys'] = bucket[image_index:image_index + self.batch_size]
            example['captions'] = captions
        return example
class DreamBoothDataset(BaseDataset):
    """DreamBooth-style dataset built from ``<repeats>_<caption>`` sub-directories.

    Every sub-directory of ``train_data_dir`` (and optionally ``reg_data_dir``)
    whose name starts with an integer followed by ``_`` contributes its
    ``*.png`` / ``*.jpg`` / ``*.webp`` files, repeated that many times.  The
    remainder of the directory name is the default caption; a caption file
    next to an image overrides it.
    """

    def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
        """Scan the training (and regularization) directories and register every image.

        Raises:
            AssertionError: if a caption file exists but is empty.
        """
        super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
                         resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)

        self.batch_size = batch_size
        self.size = min(self.width, self.height)                  # the shorter side
        self.prior_loss_weight = prior_loss_weight
        self.random_crop = random_crop
        self.latents_cache = None
        self.enable_bucket = False

        def read_caption(img_path):
            """Return the first line of the image's caption file, or None if there is none."""
            # Candidate caption file names: the image's own base name, and the base
            # name with a trailing 4-part "_cx_cy_w_h" face-detection suffix removed.
            base_name = os.path.splitext(img_path)[0]
            base_name_face_det = base_name
            tokens = base_name.split("_")
            if len(tokens) >= 5:
                base_name_face_det = "_".join(tokens[:-4])
            cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]

            caption = None
            for cap_path in cap_paths:
                if os.path.isfile(cap_path):
                    with open(cap_path, "rt", encoding='utf-8') as f:
                        lines = f.readlines()
                        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
                        caption = lines[0].strip()
                    break
            return caption

        def load_dreambooth_dir(dir_path):
            """Return ``(n_repeats, image_paths, captions)`` for one sub-directory.

            Non-directories and directories without a leading repeat count yield ``(0, [], [])``.
            """
            if not os.path.isdir(dir_path):
                # print(f"ignore file: {dir_path}")
                return 0, [], []

            tokens = os.path.basename(dir_path).split('_')
            try:
                n_repeats = int(tokens[0])
            except ValueError:
                print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir_path}")
                return 0, [], []

            caption_by_folder = '_'.join(tokens[1:])
            img_paths = glob.glob(os.path.join(dir_path, "*.png")) + glob.glob(os.path.join(dir_path, "*.jpg")) + \
                glob.glob(os.path.join(dir_path, "*.webp"))
            print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files")

            # read the per-image prompt and prefer it over the folder caption when present
            captions = []
            for img_path in img_paths:
                cap_for_img = read_caption(img_path)
                captions.append(caption_by_folder if cap_for_img is None else cap_for_img)

            return n_repeats, img_paths, captions

        print("prepare train images.")
        num_train_images = 0
        for entry in os.listdir(train_data_dir):
            n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, entry))
            num_train_images += n_repeats * len(img_paths)
            for img_path, caption in zip(img_paths, captions):
                info = ImageInfo(img_path, n_repeats, caption, False, img_path)
                self.register_image(info)
        print(f"{num_train_images} train images with repeating.")
        self.num_train_images = num_train_images

        # count the reg images and bring them to the same number as the training images
        num_reg_images = 0
        if reg_data_dir:
            print("prepare reg images.")
            reg_infos: list[ImageInfo] = []

            for entry in os.listdir(reg_data_dir):
                n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, entry))
                num_reg_images += n_repeats * len(img_paths)
                for img_path, caption in zip(img_paths, captions):
                    info = ImageInfo(img_path, n_repeats, caption, True, img_path)
                    reg_infos.append(info)

            print(f"{num_reg_images} reg images.")
            if num_train_images < num_reg_images:
                print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")

            if num_reg_images == 0:
                print("no regularization images / 正則化画像が見つかりませんでした")
            else:
                # Registering the same ImageInfo twice only overwrites its dict entry, so
                # extra passes must raise num_repeats instead of re-registering; otherwise
                # the reg images would never be repeated up to the training count.
                n = 0
                first_pass = True
                while n < num_train_images:
                    for info in reg_infos:
                        if first_pass:
                            self.register_image(info)
                            n += info.num_repeats
                        else:
                            info.num_repeats += 1
                            n += 1
                        if n >= num_train_images:
                            break
                    first_pass = False

        self.num_reg_images = num_reg_images
class FineTuningDataset(BaseDataset):
    """Caption/tag based dataset driven by a metadata dict (``image_key -> {...}``).

    Per image the metadata entries ``caption``, ``tags`` and
    ``train_resolution`` are read.  Pre-computed latent ``.npz`` files are
    used unless ``color_aug`` is enabled, and only if every image has them.
    """

    def __init__(self, metadata, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None:
        """Register every image listed in ``metadata``.

        Raises:
            AssertionError: if an image file cannot be found, or an image has
                neither a caption nor tags.
        """
        super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
                         resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset)

        self.metadata = metadata
        self.train_data_dir = train_data_dir
        self.batch_size = batch_size

        for image_key, img_md in metadata.items():
            # resolve the image path
            if os.path.exists(image_key):
                abs_path = image_key
            else:
                # rather crude, but no better way comes to mind
                candidates = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) +
                              glob.glob(os.path.join(train_data_dir, f"{image_key}.webp")))
                # Report the key that was searched for; the candidate list is empty here.
                assert len(candidates) >= 1, f"no image / 画像がありません: {image_key}"
                abs_path = candidates[0]

            # caption and tags are joined when both are present
            caption = img_md.get('caption')
            tags = img_md.get('tags')
            if caption is None:
                caption = tags
            elif tags is not None and len(tags) > 0:
                caption = caption + ', ' + tags
            assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"

            image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path)
            image_info.image_size = img_md.get('train_resolution')

            if not self.color_aug:
                # if npz exists, use them
                image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key)

            self.register_image(image_info)
        self.num_train_images = len(metadata) * dataset_repeats
        self.num_reg_images = 0

        # check existence of all npz files
        if not self.color_aug:
            npz_any = False
            npz_all = True
            for image_info in self.image_data.values():
                has_npz = image_info.latents_npz is not None
                npz_any = npz_any or has_npz

                if self.flip_aug:
                    has_npz = has_npz and image_info.latents_npz_flipped is not None
                npz_all = npz_all and has_npz

                if npz_any and not npz_all:
                    break

            if not npz_any:
                print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します")
            elif not npz_all:
                print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します")
                for image_info in self.image_data.values():
                    image_info.latents_npz = image_info.latents_npz_flipped = None

        # check min/max bucket size
        sizes = set()
        for image_info in self.image_data.values():
            if image_info.image_size is None:
                sizes = None                  # not calculated
                break
            sizes.add(image_info.image_size[0])
            sizes.add(image_info.image_size[1])

        if sizes is None:
            self.min_bucket_reso = self.max_bucket_reso = None      # set as not calculated
        else:
            self.min_bucket_reso = min(sizes)
            self.max_bucket_reso = max(sizes)

    def image_key_to_npz_file(self, image_key):
        """Return ``(npz_path, flipped_npz_path)`` for ``image_key``; missing files are ``None``.

        The key is first tried as a path of its own (extension replaced by
        ``.npz`` / ``_flip.npz``), then as a name relative to ``train_data_dir``.
        """
        base_name = os.path.splitext(image_key)[0]
        npz_file_norm = base_name + '.npz'

        if os.path.exists(npz_file_norm):
            # image_key is full path
            npz_file_flip = base_name + '_flip.npz'
            if not os.path.exists(npz_file_flip):
                npz_file_flip = None
            return npz_file_norm, npz_file_flip

        # image_key is relative path
        npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz')
        npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz')

        if not os.path.exists(npz_file_norm):
            npz_file_norm = None
            npz_file_flip = None
        elif not os.path.exists(npz_file_flip):
            npz_file_flip = None

        return npz_file_norm, npz_file_flip
# endregion
# region module replacement
"""
高速化のためのモジュール入れ替え
"""
# CrossAttention implemented with FlashAttention
# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
# constants
# Lower bound applied (via clamp) to the per-block softmax row sums in the flash attention forward pass.
EPSILON = 1e-6
# helper functions
def exists(val):
    """Tell whether ``val`` is a real value, i.e. anything other than ``None``."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is ``None``."""
    if val is None:
        return d
    return val
# flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135
class FlashAttentionFunction(Function):
    """Block-wise attention with a hand-written backward pass.

    Queries are processed in chunks of ``q_bucket_size`` rows and keys/values
    in chunks of ``k_bucket_size`` rows; running row maxima and row sums are
    updated in place so the result is accumulated chunk by chunk.
    """

    @staticmethod
    @torch.no_grad()
    def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
        """ Algorithm 2 in the paper """

        device = q.device
        dtype = q.dtype
        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # Output and running statistics; the chunks below are views that are updated in place.
        o = torch.zeros_like(q)
        all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
        all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)

        scale = (q.shape[-1] ** -0.5)

        if mask is None:
            # one placeholder per query chunk so the zip below stays aligned
            mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
        else:
            mask = rearrange(mask, 'b n -> b 1 1 n')
            mask = mask.split(q_bucket_size, dim=-1)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
        )

        for row_ind, (q_chunk, o_chunk, row_mask, row_sums, row_maxes) in enumerate(row_splits):
            q_start_index = row_ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
            )

            for col_ind, (k_chunk, v_chunk) in enumerate(col_splits):
                k_start_index = col_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', q_chunk, k_chunk) * scale

                if row_mask is not None:
                    attn_weights.masked_fill_(~row_mask, max_neg_value)

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((q_chunk.shape[-2], k_chunk.shape[-2]), dtype=torch.bool,
                                             device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # softmax numerator of this block, shifted by the block's own row maximum
                block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
                attn_weights -= block_row_maxes
                exp_weights = torch.exp(attn_weights)

                if row_mask is not None:
                    exp_weights.masked_fill_(~row_mask, 0.)

                block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)

                new_row_maxes = torch.maximum(block_row_maxes, row_maxes)

                exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, v_chunk)

                # factors that rescale the old and the new contribution to the common maximum
                exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
                exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)

                new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums

                o_chunk.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)

                # must happen after the output update, which still needs the old statistics
                row_maxes.copy_(new_row_maxes)
                row_sums.copy_(new_row_sums)

        ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
        ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)

        return o

    @staticmethod
    @torch.no_grad()
    def backward(ctx, do):
        """ Algorithm 4 in the paper """

        causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
        q, k, v, o, all_row_sums, all_row_maxes = ctx.saved_tensors

        device = q.device

        max_neg_value = -torch.finfo(q.dtype).max
        qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)

        # gradient buffers; their chunks are accumulated into in place
        dq = torch.zeros_like(q)
        dk = torch.zeros_like(k)
        dv = torch.zeros_like(v)

        row_splits = zip(
            q.split(q_bucket_size, dim=-2),
            o.split(q_bucket_size, dim=-2),
            do.split(q_bucket_size, dim=-2),
            mask,
            all_row_sums.split(q_bucket_size, dim=-2),
            all_row_maxes.split(q_bucket_size, dim=-2),
            dq.split(q_bucket_size, dim=-2)
        )

        for row_ind, (q_chunk, o_chunk, do_chunk, row_mask, sums_chunk, maxes_chunk, dq_chunk_acc) in enumerate(row_splits):
            q_start_index = row_ind * q_bucket_size - qk_len_diff

            col_splits = zip(
                k.split(k_bucket_size, dim=-2),
                v.split(k_bucket_size, dim=-2),
                dk.split(k_bucket_size, dim=-2),
                dv.split(k_bucket_size, dim=-2),
            )

            for col_ind, (k_chunk, v_chunk, dk_chunk_acc, dv_chunk_acc) in enumerate(col_splits):
                k_start_index = col_ind * k_bucket_size

                attn_weights = einsum('... i d, ... j d -> ... i j', q_chunk, k_chunk) * scale

                if causal and q_start_index < (k_start_index + k_bucket_size - 1):
                    causal_mask = torch.ones((q_chunk.shape[-2], k_chunk.shape[-2]), dtype=torch.bool,
                                             device=device).triu(q_start_index - k_start_index + 1)
                    attn_weights.masked_fill_(causal_mask, max_neg_value)

                # recompute the attention probabilities from the saved row statistics
                exp_attn_weights = torch.exp(attn_weights - maxes_chunk)

                if row_mask is not None:
                    exp_attn_weights.masked_fill_(~row_mask, 0.)

                p = exp_attn_weights / sums_chunk

                dv_chunk = einsum('... i j, ... i d -> ... j d', p, do_chunk)
                dp = einsum('... i d, ... j d -> ... i j', do_chunk, v_chunk)

                delta = (do_chunk * o_chunk).sum(dim=-1, keepdims=True)
                ds = p * scale * (dp - delta)

                dq_chunk = einsum('... i j, ... j d -> ... i d', ds, k_chunk)
                dk_chunk = einsum('... i j, ... i d -> ... j d', ds, q_chunk)

                dq_chunk_acc.add_(dq_chunk)
                dk_chunk_acc.add_(dk_chunk)
                dv_chunk_acc.add_(dv_chunk)

        # one gradient slot per forward() input; only q, k and v receive gradients
        return dq, dk, dv, None, None, None, None
def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
    """Install at most one alternative CrossAttention forward.

    ``mem_eff_attn`` takes precedence over ``xformers``; with neither flag set
    nothing is changed.  ``unet`` itself is not touched by this function.
    """
    if mem_eff_attn:
        replace_unet_cross_attn_to_memory_efficient()
        return
    if xformers:
        replace_unet_cross_attn_to_xformers()
def replace_unet_cross_attn_to_memory_efficient():
    """Replace ``diffusers.models.attention.CrossAttention.forward`` with a FlashAttention-based one."""
    print("Replace CrossAttention.forward to use FlashAttention")
    flash_func = FlashAttentionFunction

    def forward_flash_attn(self, x, context=None, mask=None):
        """Attention forward using ``FlashAttentionFunction`` (non-causal, fixed chunk sizes)."""
        q_bucket_size = 512
        k_bucket_size = 1024

        heads = self.heads
        q = self.to_q(x)

        # without an explicit context this is self-attention over x
        if context is None:
            context = x
        context = context.to(x.dtype)
        k = self.to_k(context)
        v = self.to_v(context)
        del context, x

        # split the channel dimension into heads
        q = rearrange(q, 'b n (h d) -> b h n d', h=heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=heads)

        out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)

        out = rearrange(out, 'b h n d -> b n (h d)')

        # diffusers 0.7.0~
        for layer_index in (0, 1):
            out = self.to_out[layer_index](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_flash_attn
def replace_unet_cross_attn_to_xformers():
    """Replace ``diffusers.models.attention.CrossAttention.forward`` with an xformers-based one.

    Raises:
        ImportError: if ``xformers.ops`` cannot be imported.
    """
    print("Replace CrossAttention.forward to use xformers")
    try:
        import xformers.ops
    except ImportError:
        raise ImportError("No xformers / xformersがインストールされていないようです")

    def forward_xformers(self, x, context=None, mask=None):
        """Attention forward using ``xformers.ops.memory_efficient_attention``; ``mask`` is not used."""
        heads = self.heads
        q_in = self.to_q(x)

        # without an explicit context this is self-attention over x
        context = default(context, x)
        context = context.to(x.dtype)

        k_in = self.to_k(context)
        v_in = self.to_v(context)

        # new format: heads get their own axis, 'b n h d'
        # (the legacy format folded them into the batch axis, '(b h) n d')
        q = rearrange(q_in, 'b n (h d) -> b n h d', h=heads)
        k = rearrange(k_in, 'b n (h d) -> b n h d', h=heads)
        v = rearrange(v_in, 'b n (h d) -> b n h d', h=heads)
        del q_in, k_in, v_in

        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)        # xformers picks the best kernel itself

        out = rearrange(out, 'b n h d -> b n (h d)', h=heads)

        # diffusers 0.7.0~
        for layer_index in (0, 1):
            out = self.to_out[layer_index](out)
        return out

    diffusers.models.attention.CrossAttention.forward = forward_xformers
# endregion
def collate_fn(examples):
    """Pass-through collate: hand back the first (index 0) element of ``examples`` unchanged."""
    first_example = examples[0]
    return first_example
def train(args):
cache_latents = args.cache_latents
# latentsをキャッシュする場合のオプション設定を確認する
if cache_latents:
assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません"
# その他のオプション設定を確認する
if args.v_parameterization and not args.v2:
print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
if args.v2 and args.clip_skip is not None:
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
use_dreambooth_method = args.in_json is None
# モデル形式のオプション設定を確認する:
load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
# 乱数系列を初期化する
if args.seed is not None:
set_seed(args.seed)
# tokenizerを読み込む
print("prepare tokenizer")
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
if args.max_token_length is not None:
print(f"update token length: {args.max_token_length}")
# 学習データを用意する
assert args.resolution is not None, f"resolution is required / resolution解像度を指定してください"
resolution = tuple([int(r) for r in args.resolution.split(',')])
if len(resolution) == 1:
resolution = (resolution[0], resolution[0])
assert len(resolution) == 2, \
f"resolution must be 'size' or 'width,height' / resolution解像度'サイズ'または'','高さ'で指定してください: {args.resolution}"
if args.face_crop_aug_range is not None:
face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
assert len(
face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
else:
face_crop_aug_range = None
# データセットを準備する
if use_dreambooth_method:
print("Use DreamBooth method.")
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.debug_dataset)
else:
print("Train with captions.")
# メタデータを読み込む
if os.path.exists(args.in_json):
print(f"loading existing metadata: {args.in_json}")
with open(args.in_json, "rt", encoding='utf-8') as f:
metadata = json.load(f)
else:
print(f"no metadata / メタデータファイルがありません: {args.in_json}")
return
if args.color_aug:
print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます")
train_dataset = FineTuningDataset(metadata, args.train_batch_size, args.train_data_dir,
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset)
if train_dataset.min_bucket_reso is not None and (args.enable_bucket or train_dataset.min_bucket_reso != train_dataset.max_bucket_reso):
print(f"using bucket info in metadata / メタデータ内のbucket情報を使います")
args.min_bucket_reso = train_dataset.min_bucket_reso
args.max_bucket_reso = train_dataset.max_bucket_reso
args.enable_bucket = True
print(f"min bucket reso: {args.min_bucket_reso}, max bucket reso: {args.max_bucket_reso}")
if args.enable_bucket:
assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
train_dataset.make_buckets(args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso)
if args.debug_dataset:
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します")
k = 0
for example in train_dataset:
if example['latents'] is not None:
print("sample has latents from npz file")
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
if example['images'] is not None:
im = example['images'][j]
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27:
break
if k == 27 or example['images'] is None:
break
return
if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
return
# acceleratorを準備する
print("prepare accelerator")
if args.logging_dir is None:
log_with = None
logging_dir = None
else:
log_with = "tensorboard"
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision,
log_with=log_with, logging_dir=logging_dir)
# accelerateの互換性問題を解決する
accelerator_0_15 = True
try:
accelerator.unwrap_model("dummy", True)
print("Using accelerator 0.15.0 or above.")
except TypeError:
accelerator_0_15 = False
def unwrap_model(model):
    """Return `model` unwrapped through the accelerator.

    The extra positional `True` is only passed when the earlier probe call
    showed that `accelerator.unwrap_model` accepts a second argument
    (`accelerator_0_15` is True); otherwise the single-argument form is used.
    NOTE(review): the second argument is presumably `keep_fp32_wrapper` --
    confirm against the installed accelerate version.
    """
    if not accelerator_0_15:
        # Older accelerate: the probe raised TypeError for the extra argument.
        return accelerator.unwrap_model(model)
    return accelerator.unwrap_model(model, True)
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype = torch.float32
if args.mixed_precision == "fp16":
weight_dtype = torch.float16
elif args.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
save_dtype = None
if args.save_precision == "fp16":
save_dtype = torch.float16
elif args.save_precision == "bf16":
save_dtype = torch.bfloat16
elif args.save_precision == "float":
save_dtype = torch.float32
# モデルを読み込む
if load_stable_diffusion_format:
print("load StableDiffusion checkpoint")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
text_encoder = pipe.text_encoder
vae = pipe.vae
unet = pipe.unet
del pipe
# VAEを読み込む
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
print("additional VAE loaded")
# モデルに xformers とか memory efficient attention を組み込む
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset.cache_latents(vae)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# prepare network
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
key, value = net_arg.split('=')
net_kwargs[key] = value
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if args.network_weights is not None:
print("load network weights from:", args.network_weights)
network.load_weights(args.network_weights)
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
network.enable_gradient_checkpointing() # may have no effect
# 学習に必要なクラスを準備する
print("prepare optimizer, data loader etc.")
# 8-bit Adamを使う
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
print("use 8-bit Adam optimizer")
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
# dataloaderを準備する
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
print("enable full fp16 training.")
# unet.to(weight_dtype)
# text_encoder.to(weight_dtype)
network.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if train_unet and train_text_encoder:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler)
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler)
elif train_text_encoder:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler)
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler)
unet.requires_grad_(False)
unet.to(accelerator.device, dtype=weight_dtype)
unet.eval()
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder.eval()
network.prepare_grad_etc(text_encoder, unet)
if not cache_latents:
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=weight_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
    """Delegate to the saved `_unscale_grads_`, always passing True as the last argument.

    The caller-supplied `allow_fp16` is deliberately ignored so that the
    original routine is invoked with fp16 allowed (used for full fp16 training).
    """
    forced_allow_fp16 = True
    return org_unscale_grads(optimizer, inv_scale, found_inf, forced_allow_fp16)
accelerator.scaler._unscale_grads_ = _unscale_grads_replacer
# resumeする
if args.resume is not None:
print(f"resume training from state: {args.resume}")
accelerator.load_state(args.resume)
# epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
print("running training / 学習開始")
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
print(f" num epochs / epoch数: {num_train_epochs}")
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
if accelerator.is_main_process:
accelerator.init_trackers("network_train")
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
# 指定したステップ数までText Encoderを学習するepoch最初の状態
network.on_epoch_start(text_encoder, unet)
loss_total = 0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(network):
with torch.no_grad():
# latentに変換
if batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
b_size = latents.shape[0]
with torch.set_grad_enabled(train_text_encoder):
# Get the text embedding for conditioning
input_ids = batch["input_ids"].to(accelerator.device)
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
if args.clip_skip is None:
encoder_hidden_states = text_encoder(input_ids)[0]
else:
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # なぜかこれが必要
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if args.max_token_length is not None:
if args.v2:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, args.max_token_length, tokenizer.model_max_length):
states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if args.v_parameterization:
# v-parameterization training
# Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = network.get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss
avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone()
if args.save_every_n_epochs is not None:
if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
print("saving checkpoint.")
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as)
unwrap_model(network).save_weights(ckpt_file, save_dtype)
if args.save_state:
print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
is_main_process = accelerator.is_main_process
if is_main_process:
network = unwrap_model(network)
accelerator.end_training()
if args.save_state:
print("saving last state.")
os.makedirs(args.output_dir, exist_ok=True)
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as)
print(f"save trained model to {ckpt_file}")
network.save_weights(ckpt_file, save_dtype)
print("model saved.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the network training script.

    Only registers the options; nothing is parsed here, so the parser can be
    created (and inspected or tested) without starting a training run.

    Returns:
        An `argparse.ArgumentParser` with every training option registered,
        in the same order and with the same defaults as before.
    """
    parser = argparse.ArgumentParser()

    # --- base model ---
    parser.add_argument("--v2", action='store_true',
                        help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
    parser.add_argument("--v_parameterization", action='store_true',
                        help='enable v-parameterization training / v-parameterization学習を有効にする')
    parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
                        help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
    parser.add_argument("--network_weights", type=str, default=None,
                        help="pretrained weights for network / 学習するネットワークの初期重み")

    # --- dataset and captions ---
    parser.add_argument("--shuffle_caption", action="store_true",
                        help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
    parser.add_argument("--keep_tokens", type=int, default=None,
                        help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す")
    parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
    parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
    parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル")
    parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
    parser.add_argument("--dataset_repeats", type=int, default=1,
                        help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数")

    # --- output, checkpoints and resume ---
    parser.add_argument("--output_dir", type=str, default=None,
                        help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
    parser.add_argument("--save_precision", type=str, default=None,
                        choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
    parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
                        help="format to save the model (default is .pt) / モデル保存時の形式デフォルトはpt")
    parser.add_argument("--save_every_n_epochs", type=int, default=None,
                        help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
    parser.add_argument("--save_state", action="store_true",
                        help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
    parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")

    # --- augmentation and image handling ---
    parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
    parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
    parser.add_argument("--face_crop_aug_range", type=str, default=None,
                        help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する2.0,4.0")
    parser.add_argument("--random_crop", action="store_true",
                        help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする顔を中心としたaugmentationを行うときに画風の学習用に指定する")
    parser.add_argument("--debug_dataset", action="store_true",
                        help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
    parser.add_argument("--resolution", type=str, default=None,
                        help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
    parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
    parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
                        help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長未指定で75、150または225が指定可")

    # --- memory / speed options ---
    parser.add_argument("--use_8bit_adam", action="store_true",
                        help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使うbitsandbytesのインストールが必要")
    parser.add_argument("--mem_eff_attn", action="store_true",
                        help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
    parser.add_argument("--xformers", action="store_true",
                        help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
    parser.add_argument("--vae", type=str, default=None,
                        help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
    parser.add_argument("--cache_latents", action="store_true",
                        help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可")

    # --- aspect-ratio bucketing ---
    parser.add_argument("--enable_bucket", action="store_true",
                        help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
    parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
    parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")

    # --- optimization ---
    parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
    parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
    parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
    # Option currently disabled; the registration is kept here for reference.
    # parser.add_argument("--stop_text_encoder_training", type=int, default=None,
    #                     help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
    parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
    parser.add_argument("--gradient_checkpointing", action="store_true",
                        help="enable gradient checkpointing / gradient checkpointingを有効にする")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
    parser.add_argument("--mixed_precision", type=str, default="no",
                        choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
    parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する")
    parser.add_argument("--clip_skip", type=int, default=None,
                        help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")

    # --- logging and lr schedule ---
    parser.add_argument("--logging_dir", type=str, default=None,
                        help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
    parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
    parser.add_argument("--lr_scheduler", type=str, default="constant",
                        help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
    parser.add_argument("--lr_warmup_steps", type=int, default=0,
                        help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数デフォルト0")

    # --- network (module being trained) ---
    parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
    parser.add_argument("--network_dim", type=int, default=None,
                        help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
    # Zero or more "key=value" strings; they are collected as a list of raw strings.
    parser.add_argument("--network_args", type=str, default=None, nargs='*',
                        help='additional arguments for network (key=value) / ネットワークへの追加の引数')
    parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
    parser.add_argument("--network_train_text_encoder_only", action="store_true",
                        help="only training Text Encoder part / Text Encoder関連部分のみ学習する")

    return parser


if __name__ == '__main__':
    # torch.cuda.set_per_process_memory_fraction(0.48)
    # Script entry point: parse the command line and hand the options to train().
    parser = setup_parser()
    args = parser.parse_args()
    train(args)