From 442eb7a29298d196591b9d856d8ff390c2211b3a Mon Sep 17 00:00:00 2001 From: bmaltais Date: Mon, 9 Jan 2023 07:47:07 -0500 Subject: [PATCH] Merge latest kohya code release into GUI repo --- README.md | 22 +- README_dreambooth.md | 204 --- README_finetune.md | 162 -- fine_tune.py | 935 ++--------- fine_tune_README.md | 465 ++++++ fine_tune_README_ja.md | 465 ++++++ gui.cmd | 1 - gui.ps1 | 2 + library/model_util.py | 10 +- library/train_util.py | 1373 +++++++++++++++++ networks/extract_lora_from_models.py | 2 +- train_db.py | 1099 ++----------- train_db_README-ja.md | 296 ++++ train_db_README.md | 295 ++++ train_network.py | 1221 +-------------- ...etwork-ja.md => train_network_README-ja.md | 2 +- ...rain_network.md => train_network_README.md | 59 +- upgrade.ps1 | 3 + 18 files changed, 3232 insertions(+), 3384 deletions(-) delete mode 100644 README_dreambooth.md delete mode 100644 README_finetune.md create mode 100644 fine_tune_README.md create mode 100644 fine_tune_README_ja.md delete mode 100644 gui.cmd create mode 100644 gui.ps1 create mode 100644 library/train_util.py create mode 100644 train_db_README-ja.md create mode 100644 train_db_README.md rename README_train_network-ja.md => train_network_README-ja.md (99%) rename README_train_network.md => train_network_README.md (70%) create mode 100644 upgrade.ps1 diff --git a/README.md b/README.md index e9ba1ad..5f5be6b 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ To install simply unzip the directory and place the cudnn_windows folder in the Run the following command to install: ``` +.\venv\Scripts\activate python .\tools\cudann_1.8_install.py ``` @@ -72,35 +73,36 @@ Once the commands have completed successfully you should be ready to use the new To run the GUI you simply use this command: ``` -gui.cmd +gui.ps1 ``` ## Dreambooth -You can find the dreambooth solution spercific [Dreambooth README](README_dreambooth.md) +You can find the dreambooth solution spercific [Dreambooth README](train_db_README.md) ## Finetune -You can find the finetune solution spercific [Finetune README](README_finetune.md) +You can find the finetune solution spercific [Finetune README](fine_tune_README.md) + +## Train Network + +You can find the train network solution spercific [Train network README](train_network_README.md) ## LoRA -You can create LoRA network by running the dedicated GUI with: +Training a LoRA currently use the `train_network.py` python code. You can create LoRA network by using the all-in-one `gui.cmd` or by running the dedicated LoRA training GUI with: ``` +.\venv\Scripts\activate python lora_gui.py ``` -or via the all in one GUI: - -``` -python kahya_gui.py -``` - Once you have created the LoRA network you can generate images via auto1111 by installing the extension found here: https://github.com/kohya-ss/sd-webui-additional-networks ## Change history +* 2023/01/10 (v20.0): + - Update code base to match latest kohys_ss code upgrade in https://github.com/kohya-ss/sd-scripts * 2023/01/09 (v19.4.3): - Add vae support to dreambooth GUI - Add gradient_checkpointing, gradient_accumulation_steps, mem_eff_attn, shuffle_caption to finetune GUI diff --git a/README_dreambooth.md b/README_dreambooth.md deleted file mode 100644 index 0822aa4..0000000 --- a/README_dreambooth.md +++ /dev/null @@ -1,204 +0,0 @@ -# Kohya_ss Dreambooth - -This repo provide all the required code to run the Dreambooth version found in this note: https://note.com/kohya_ss/n/nee3ed1649fb6 - -## Required Dependencies - -Python 3.10.6 and Git: - -- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe -- git: https://git-scm.com/download/win - -Give unrestricted script access to powershell so venv can work: - -- Open an administrator powershell window -- Type `Set-ExecutionPolicy Unrestricted` and answer A -- Close admin powershell window - -## Installation - -Open a regular Powershell terminal and type the following inside: - -```powershell -git clone https://github.com/bmaltais/kohya_ss.git -cd kohya_ss - -python -m venv --system-site-packages venv -.\venv\Scripts\activate - -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py - -accelerate config - -``` - -Answers to accelerate config: - -```txt -- 0 -- 0 -- NO -- NO -- All -- fp16 -``` - -### Optional: CUDNN 8.6 - -This step is optional but can improve the learning speed for NVidia 4090 owners... - -Due to the filesize I can't host the DLLs needed for CUDNN 8.6 on Github, I strongly advise you download them for a speed boost in sample generation (almost 50% on 4090) you can download them from here: https://b1.thefileditch.ch/mwxKTEtelILoIbMbruuM.zip - -To install simply unzip the directory and place the cudnn_windows folder in the root of the kohya_diffusers_fine_tuning repo. - -Run the following command to install: - -``` -python .\tools\cudann_1.8_install.py -``` - -## Upgrade - -When a new release comes out you can upgrade your repo with the following command: - -```powershell -cd kohya_ss -git pull -.\venv\Scripts\activate -pip install --upgrade -r requirements.txt -``` - -Once the commands have completed successfully you should be ready to use the new version. - -## GUI - -There is now support for GUI based training using gradio. You can start the complete kohya training GUI interface by running: - -```powershell -.\venv\Scripts\activate -.\kohya_gui.cmd -``` - -## CLI - -You can find various examples of how to leverage the `train_db.py` in this folder: https://github.com/bmaltais/kohya_ss/tree/master/examples - -## Support - -Drop by the discord server for support: https://discord.com/channels/1041518562487058594/1041518563242020906 - -## Quickstart screencast - -You can find a screen cast on how to use the GUI at the following location: - -[![Video](https://img.youtube.com/vi/RlvqEKj03WI/maxresdefault.jpg)](https://www.youtube.com/watch?v=RlvqEKj03WI) - -## Folders configuration - -Refer to the note to understand how to create the folde structure. In short it should look like: - -``` - -|- - |- _ -|- - |- _ -``` - -Example for `asd dog` where `asd` is the token word and `dog` is the class. In this example the regularization `dog` class images contained in the folder will be repeated only 1 time and the `asd dog` images will be repeated 20 times: - -``` -my_asd_dog_dreambooth -|- reg_dog - |- 1_dog - `- reg_image_1.png - `- reg_image_2.png - ... - `- reg_image_256.png -|- train_dog - |- 20_asd dog - `- dog1.png - ... - `- dog8.png -``` - -## Support - -Drop by the discord server for support: https://discord.com/channels/1041518562487058594/1041518563242020906 - -## Contributors - -- Lord of the universe - cacoe (twitter: @cac0e) - -## Change history - -* 12/19 (v18.5) update: - - Create model and log folder when running th dreambooth folder creation utility -* 12/19 (v18.4) update: - - Add support for shuffle_caption, save_state, resume, prior_loss_weight under "Advanced Configuration" section - - Fix issue with open/save config not working properly -* 12/19 (v18.3) update: - - fix stop encoder training issue -* 12/19 (v18.2) update: - - Fix file/folder opening behind the browser window - - Add WD14 and BLIP captioning to utilities - - Improve overall GUI layout -* 12/18 (v18.1) update: - - Add Stable Diffusion model conversion utility. Make sure to run `pip upgrade -U -r requirements.txt` after updating to this release as this introduce new pip requirements. -* 12/17 (v18) update: - - Save model as option added to train_db_fixed.py - - Save model as option added to GUI - - Retire "Model conversion" parameters that was essentially performing the same function as the new `--save_model_as` parameter -* 12/17 (v17.2) update: - - Adding new dataset balancing utility. -* 12/17 (v17.1) update: - - Adding GUI for kohya_ss called dreambooth_gui.py - - removing support for `--finetuning` as there is now a dedicated python repo for that. `--fine-tuning` is still there behind the scene until kohya_ss remove it in a future code release. - - removing cli examples as I will now focus on the GUI for training. People who prefer cli based training can still do that. -* 12/13 (v17) update: - - Added support for learning to fp16 gradient (experimental function). SD1.x can be trained with 8GB of VRAM. Specify full_fp16 options. -* 12/06 (v16) update: - - Added support for Diffusers 0.10.2 (use code in Diffusers to learn v-parameterization). - - Diffusers also supports safetensors. - - Added support for accelerate 0.15.0. -* 12/05 (v15) update: - - The script has been divided into two parts - - Support for SafeTensors format has been added. Install SafeTensors with `pip install safetensors`. The script will automatically detect the format based on the file extension when loading. Use the `--use_safetensors` option if you want to save the model as safetensor. - - The vae option has been added to load a VAE model separately. - - The log_prefix option has been added to allow adding a custom string to the log directory name before the date and time. -* 11/30 (v13) update: - - fix training text encoder at specified step (`--stop_text_encoder_training=`) that was causing both Unet and text encoder training to stop completely at the specified step rather than continue without text encoding training. -* 11/29 (v12) update: - - stop training text encoder at specified step (`--stop_text_encoder_training=`) - - tqdm smoothing - - updated fine tuning script to support SD2.0 768/v -* 11/27 (v11) update: - - DiffUsers 0.9.0 is required. Update with `pip install --upgrade -r requirements.txt` in the virtual environment. - - The way captions are handled in DreamBooth has changed. When a caption file existed, the file's caption was added to the folder caption until v10, but from v11 it is only the file's caption. Please be careful. - - Fixed a bug where prior_loss_weight was applied to learning images. Sorry for the inconvenience. - - Compatible with Stable Diffusion v2.0. Add the `--v2` option. If you are using `768-v-ema.ckpt` or `stable-diffusion-2` instead of `stable-diffusion-v2-base`, add `--v_parameterization` as well. Learn more about other options. - - Added options related to the learning rate scheduler. - - You can download and use DiffUsers models directly from Hugging Face. In addition, DiffUsers models can be saved during training. -* 11/21 (v10): - - Added minimum/maximum resolution specification when using Aspect Ratio Bucketing (min_bucket_reso/max_bucket_reso option). - - Added extension specification for caption files (caption_extention). - - Added support for images with .webp extension. - - Added a function that allows captions to learning images and regularized images. -* 11/18 (v9): - - Added support for Aspect Ratio Bucketing (enable_bucket option). (--enable_bucket) - - Added support for selecting data format (fp16/bf16/float) when saving checkpoint (--save_precision) - - Added support for saving learning state (--save_state, --resume) - - Added support for logging (--logging_dir) -* 11/14 (diffusers_fine_tuning v2): - - script name is now fine_tune.py. - - Added option to learn Text Encoder --train_text_encoder. - - The data format of checkpoint at the time of saving can be specified with the --save_precision option. You can choose float, fp16, and bf16. - - Added a --save_state option to save the learning state (optimizer, etc.) in the middle. It can be resumed with the --resume option. -* 11/9 (v8): supports Diffusers 0.7.2. To upgrade diffusers run `pip install --upgrade diffusers[torch]` -* 11/7 (v7): Text Encoder supports checkpoint files in different storage formats (it is converted at the time of import, so export will be in normal format). Changed the average value of EPOCH loss to output to the screen. Added a function to save epoch and global step in checkpoint in SD format (add values if there is existing data). The reg_data_dir option is enabled during fine tuning (fine tuning while mixing regularized images). Added dataset_repeats option that is valid for fine tuning (specified when the number of teacher images is small and the epoch is extremely short). \ No newline at end of file diff --git a/README_finetune.md b/README_finetune.md deleted file mode 100644 index 962888d..0000000 --- a/README_finetune.md +++ /dev/null @@ -1,162 +0,0 @@ -# Kohya_ss Finetune - -This python utility provide code to run the diffusers fine tuning version found in this note: https://note.com/kohya_ss/n/nbf7ce8d80f29 - -## Required Dependencies - -Python 3.10.6 and Git: - -- Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe -- git: https://git-scm.com/download/win - -Give unrestricted script access to powershell so venv can work: - -- Open an administrator powershell window -- Type `Set-ExecutionPolicy Unrestricted` and answer A -- Close admin powershell window - -## Installation - -Open a regular Powershell terminal and type the following inside: - -```powershell -git clone https://github.com/bmaltais/kohya_diffusers_fine_tuning.git -cd kohya_diffusers_fine_tuning - -python -m venv --system-site-packages venv -.\venv\Scripts\activate - -pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 -pip install --upgrade -r requirements.txt -pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl - -cp .\bitsandbytes_windows\*.dll .\venv\Lib\site-packages\bitsandbytes\ -cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\cextension.py -cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py - -accelerate config - -``` - -Answers to accelerate config: - -```txt -- 0 -- 0 -- NO -- NO -- All -- fp16 -``` - -### Optional: CUDNN 8.6 - -This step is optional but can improve the learning speed for NVidia 4090 owners... - -Due to the filesize I can't host the DLLs needed for CUDNN 8.6 on Github, I strongly advise you download them for a speed boost in sample generation (almost 50% on 4090) you can download them from here: https://b1.thefileditch.ch/mwxKTEtelILoIbMbruuM.zip - -To install simply unzip the directory and place the cudnn_windows folder in the root of the kohya_diffusers_fine_tuning repo. - -Run the following command to install: - -``` -python .\tools\cudann_1.8_install.py -``` - -## Upgrade - -When a new release comes out you can upgrade your repo with the following command: - -```powershell -cd kohya_ss -git pull -.\venv\Scripts\activate -pip install --upgrade -r requirements.txt -``` - -Once the commands have completed successfully you should be ready to use the new version. - -## Folders configuration - -Simply put all the images you will want to train on in a single directory. It does not matter what size or aspect ratio they have. It is your choice. - -## Captions - -Each file need to be accompanied by a caption file describing what the image is about. For example, if you want to train on cute dog pictures you can put `cute dog` as the caption in every file. You can use the `tools\caption.ps1` sample code to help out with that: - -```powershell -$folder = "sample" -$file_pattern="*.*" -$caption_text="cute dog" - -$files = Get-ChildItem "$folder\$file_pattern" -Include *.png, *.jpg, *.webp -File -foreach ($file in $files) { - if (-not(Test-Path -Path $folder\"$($file.BaseName).txt" -PathType Leaf)) { - New-Item -ItemType file -Path $folder -Name "$($file.BaseName).txt" -Value $caption_text - } -} - -You can also use the `Captioning` tool found under the `Utilities` tab in the GUI. -``` - -## GUI - -There is now support for GUI based training using gradio. You can start the complete kohya training GUI interface by running: - -```powershell -.\venv\Scripts\activate -.\kohya_gui.cmd -``` - -## CLI - -You can find various examples of how to leverage the `fine_tune.py` in this folder: https://github.com/bmaltais/kohya_ss/tree/master/examples - -## Support - -Drop by the discord server for support: https://discord.com/channels/1041518562487058594/1041518563242020906 - -## Change history - -* 12/20 (v9.6) update: - - fix issue with config file save and opening -* 12/19 (v9.5) update: - - Fix file/folder dialog opening behind the browser window - - Update GUI layout to be more logical -* 12/18 (v9.4) update: - - Add WD14 tagging to utilities -* 12/18 (v9.3) update: - - Add logging option -* 12/18 (v9.2) update: - - Add BLIP Captioning utility -* 12/18 (v9.1) update: - - Add Stable Diffusion model conversion utility. Make sure to run `pip upgrade -U -r requirements.txt` after updating to this release as this introduce new pip requirements. -* 12/17 (v9) update: - - Save model as option added to fine_tune.py - - Save model as option added to GUI - - Retirement of cli based documentation. Will focus attention to GUI based training -* 12/13 (v8): - - WD14Tagger now works on its own. - - Added support for learning to fp16 up to the gradient. Go to "Building the environment and preparing scripts for Diffusers for more info". -* 12/10 (v7): - - We have added support for Diffusers 0.10.2. - - In addition, we have made other fixes. - - For more information, please see the section on "Building the environment and preparing scripts for Diffusers" in our documentation. -* 12/6 (v6): We have responded to reports that some models experience an error when saving in SafeTensors format. -* 12/5 (v5): - - .safetensors format is now supported. Install SafeTensors as "pip install safetensors". When loading, it is automatically determined by extension. Specify use_safetensors options when saving. - - Added an option to add any string before the date and time log directory name log_prefix. - - Cleaning scripts now work without either captions or tags. -* 11/29 (v4): - - DiffUsers 0.9.0 is required. Update as "pip install -U diffusers[torch]==0.9.0" in the virtual environment, and update the dependent libraries as "pip install --upgrade -r requirements.txt" if other errors occur. - - Compatible with Stable Diffusion v2.0. Add the --v2 option when training (and pre-fetching latents). If you are using 768-v-ema.ckpt or stable-diffusion-2 instead of stable-diffusion-v2-base, add --v_parameterization as well when learning. Learn more about other options. - - The minimum resolution and maximum resolution of the bucket can be specified when pre-fetching latents. - - Corrected the calculation formula for loss (fixed that it was increasing according to the batch size). - - Added options related to the learning rate scheduler. - - So that you can download and learn DiffUsers models directly from Hugging Face. In addition, DiffUsers models can be saved during training. - - Available even if the clean_captions_and_tags.py is only a caption or a tag. - - Other minor fixes such as changing the arguments of the noise scheduler during training. -* 11/23 (v3): - - Added WD14Tagger tagging script. - - A log output function has been added to the fine_tune.py. Also, fixed the double shuffling of data. - - Fixed misspelling of options for each script (caption_extention→caption_extension will work for the time being, even if it remains outdated). diff --git a/fine_tune.py b/fine_tune.py index 39fc15c..1a94870 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -1,263 +1,60 @@ -# v2: select precision for saved checkpoint -# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset) -# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model -# v5: refactor to use model_util, support safetensors, add settings to use Diffusers' xformers, add log prefix -# v6: model_util update -# v7: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0, support full path in metadata -# v8: experimental full fp16 training. -# v9: add keep_tokens and save_model_as option, flip augmentation - -# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします -# License: -# Copyright 2022 Kohya S. @kohya_ss -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# License of included scripts: - -# Diffusers: ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE - -# Memory efficient attention: -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE +# training with captions +# XXX dropped option: hypernetwork training import argparse +import gc import math import os -import random -import json -import importlib -import time from tqdm import tqdm import torch -from accelerate import Accelerator from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import numpy as np -from einops import rearrange -from torch import einsum +from diffusers import DDPMScheduler -import library.model_util as model_util - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -# checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -LAST_DIFFUSERS_DIR_NAME = "last" -EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" +import library.train_util as train_util def collate_fn(examples): return examples[0] -class FineTuningDataset(torch.utils.data.Dataset): - def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, dataset_repeats, debug) -> None: - super().__init__() - - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - self.debug = debug - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - print("make buckets") - - # 最初に数を数える - self.bucket_resos = set() - for img_md in metadata.values(): - if 'train_resolution' in img_md: - self.bucket_resos.add(tuple(img_md['train_resolution'])) - self.bucket_resos = list(self.bucket_resos) - self.bucket_resos.sort() - print(f"number of buckets: {len(self.bucket_resos)}") - - reso_to_index = {} - for i, reso in enumerate(self.bucket_resos): - reso_to_index[reso] = i - - # bucketに割り当てていく - self.buckets = [[] for _ in range(len(self.bucket_resos))] - n = 1 if dataset_repeats is None else dataset_repeats - images_count = 0 - for image_key, img_md in metadata.items(): - if 'train_resolution' not in img_md: - continue - if not os.path.exists(self.image_key_to_npz_file(image_key)): - continue - - reso = tuple(img_md['train_resolution']) - for _ in range(n): - self.buckets[reso_to_index[reso]].append(image_key) - images_count += n - - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - self.images_count = images_count - - def show_buckets(self): - for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def image_key_to_npz_file(self, image_key): - npz_file_norm = os.path.splitext(image_key)[0] + '.npz' - if os.path.exists(npz_file_norm): - if random.random() < .5: - npz_file_flip = os.path.splitext(image_key)[0] + '_flip.npz' - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm - - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - if random.random() < .5: - npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm - - def load_latent(self, image_key): - return np.load(self.image_key_to_npz_file(image_key))['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size - - input_ids_list = [] - latents_list = [] - captions = [] - for image_key in bucket[image_index:image_index + self.batch_size]: - img_md = self.metadata[image_key] - caption = img_md.get('caption') - tags = img_md.get('tags') - - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}" - - latents = self.load_latent(image_key) - - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() - - captions.append(caption) - - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - - input_ids_list.append(input_ids) - latents_list.append(torch.FloatTensor(latents)) - - example = {} - example['input_ids'] = torch.stack(input_ids_list) - example['latents'] = torch.stack(latents_list) - if self.debug: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example - - -def save_hypernetwork(output_file, hypernetwork): - state_dict = hypernetwork.get_state_dict() - torch.save(state_dict, output_file) - - def train(args): - fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + cache_latents = args.cache_latents - # モデル形式のオプション設定を確認する - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する + tokenizer = train_util.load_tokenizer(args) + + train_dataset = train_util.FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, + tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, + args.dataset_repeats, args.debug_dataset) + train_dataset.make_buckets() + + if args.debug_dataset: + train_util.debug_dataset(train_dataset) + return + if len(train_dataset) == 0: + print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") + return + + # acceleratorを準備する + print("prepare accelerator") + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats if load_stable_diffusion_format: src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_diffusers_model_path = None @@ -272,110 +69,6 @@ def train(args): save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 乱数系列を初期化する - if args.seed is not None: - set_seed(args.seed) - - # メタデータを読み込む - if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") - return - - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) - - if args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") - - # datasetを用意する - print("prepare dataset") - train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size, - tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - args.dataset_repeats, args.debug_dataset) - - print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") - print(f"Total images / 画像数: {train_dataset.images_count}") - - if len(train_dataset) == 0: - print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") - return - - if args.debug_dataset: - train_dataset.show_buckets() - i = 0 - for example in train_dataset: - print(f"image: {example['image_keys']}") - print(f"captions: {example['captions']}") - print(f"latents: {example['latents'].shape}") - print(f"input_ids: {example['input_ids'].shape}") - print(example['input_ids']) - i += 1 - if i >= 8: - break - return - - # acceleratorを準備する - print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 - - # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる - text_encoder = pipe.text_encoder - unet = pipe.unet - vae = pipe.vae - del pipe - vae.to("cpu") # 保存時にしか使わないので、メモリを開けるためCPUに移しておく - # Diffusers版のxformers使用フラグを設定する関数 def set_diffusers_xformers_flag(model, valid): # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう @@ -403,48 +96,44 @@ def train(args): # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある print("Disable Diffusers' xformers") set_diffusers_xformers_flag(unet, False) - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - if not fine_tuning: - # Hypernetwork - print("import hypernetwork module:", args.hypernetwork_module) - hyp_module = importlib.import_module(args.hypernetwork_module) - - hypernetwork = hyp_module.Hypernetwork() - - if args.hypernetwork_weights is not None: - print("load hypernetwork weights from:", args.hypernetwork_weights) - hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') - success = hypernetwork.load_from_state_dict(hyp_sd) - assert success, "hypernetwork weights loading failed." - - print("apply hypernetwork") - hypernetwork.apply_to_diffusers(None, text_encoder, unet) + # 学習を準備する + if cache_latents: + vae.to(accelerator.device, dtype=weight_dtype) + vae.requires_grad_(False) + vae.eval() + with torch.no_grad(): + train_dataset.cache_latents(vae) + vae.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() # 学習を準備する:モデルを適切な状態にする training_models = [] - if fine_tuning: - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) - if args.train_text_encoder: - print("enable text encoder training") - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - text_encoder.eval() + if args.train_text_encoder: + print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) else: - unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない - unet.requires_grad_(False) - unet.eval() text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - training_models.append(hypernetwork) + text_encoder.requires_grad_(False) # text encoderは学習しない + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + text_encoder.train() # required for gradient_checkpointing + else: + text_encoder.eval() + + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) for m in training_models: m.requires_grad_(True) @@ -480,40 +169,23 @@ def train(args): lr_scheduler = diffusers.optimization.get_scheduler( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) - # acceleratorがなんかよろしくやってくれるらしい + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" print("enable full fp16 training.") + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - if fine_tuning: - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) - else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + # acceleratorがなんかよろしくやってくれるらしい + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: - if args.full_fp16: - unet.to(weight_dtype) - hypernetwork.to(weight_dtype) - - unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, hypernetwork, optimizer, train_dataloader, lr_scheduler) + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer - - # TODO accelerateのconfigに指定した型とオプション指定の型とをチェックして異なれば警告を出す + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -527,7 +199,7 @@ def train(args): # 学習する total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") - print(f" num examples / サンプル数: {train_dataset.images_count}") + print(f" num examples / サンプル数: {train_dataset.num_train_images}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") @@ -538,17 +210,12 @@ def train(args): progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # v4で更新:clip_sample=Falseに - # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀') noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) if accelerator.is_main_process: - accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork") + accelerator.init_trackers("finetuning") - # 以下 train_dreambooth.py からほぼコピペ for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") for m in training_models: @@ -557,46 +224,20 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - latents = batch["latents"].to(accelerator.device) - latents = latents * 0.18215 + with torch.no_grad(): + if "latents" in batch and batch["latents"] is not None: + latents = batch["latents"].to(accelerator.device) + else: + # latentに変換 + latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 b_size = latents.shape[0] - # with torch.no_grad(): with torch.set_grad_enabled(args.train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -614,7 +255,6 @@ def train(args): if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -657,403 +297,40 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - - if fine_tuning: - if save_stable_diffusion_format: - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - else: - save_hypernetwork(ckpt_file, unwrap_model(hypernetwork)) - - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) is_main_process = accelerator.is_main_process if is_main_process: - if fine_tuning: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - else: - hypernetwork = unwrap_model(hypernetwork) + unet = unwrap_model(unet) + text_encoder = unwrap_model(text_encoder) accelerator.end_training() if args.save_state: - print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) - - if fine_tuning: - if save_stable_diffusion_format: - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) - else: - # Create the pipeline using using the trained modules and save it. - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, - src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) - else: - print(f"save trained model to {ckpt_file}") - save_hypernetwork(ckpt_file, hypernetwork) - + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(torch.autograd.function.Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = (q.shape[-1] ** -0.5) - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) - - p = exp_attn_weights / lc - - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention (not xformers)") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, 'b h n d -> b n (h d)') - - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - - context = default(context, x) - context = context.to(x.dtype) - - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context - - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers -# endregion - - if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--keep_tokens", type=int, default=None, - help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") - parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") - parser.add_argument("--hypernetwork_module", type=str, default=None, - help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール') - parser.add_argument("--hypernetwork_weights", type=str, default=None, - help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)') - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, - help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--train_batch_size", type=int, default=1, - help="batch size for training / 学習時のバッチサイズ") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, False, True) + train_util.add_training_arguments(parser, False) + train_util.add_sd_saving_arguments(parser) + parser.add_argument("--diffusers_xformers", action='store_true', - help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + help='use xformers by diffusers / Diffusersでxformersを使用する') + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") args = parser.parse_args() - train(args) \ No newline at end of file + train(args) diff --git a/fine_tune_README.md b/fine_tune_README.md new file mode 100644 index 0000000..7ffd05d --- /dev/null +++ b/fine_tune_README.md @@ -0,0 +1,465 @@ +It is a fine tuning that corresponds to NovelAI's proposed learning method, automatic captioning, tagging, Windows + VRAM 12GB (for v1.4/1.5) environment, etc. + +## overview +Fine tuning of U-Net of Stable Diffusion using Diffusers. It corresponds to the following improvements in NovelAI's article (For Aspect Ratio Bucketing, I referred to NovelAI's code, but the final code is all original). + +* Use the output of the penultimate layer instead of the last layer of CLIP (Text Encoder). +* Learning at non-square resolutions (Aspect Ratio Bucketing). +* Extend token length from 75 to 225. +* Captioning with BLIP (automatic creation of captions), automatic tagging with DeepDanbooru or WD14Tagger. +* Also supports Hypernetwork learning. +* Supports Stable Diffusion v2.0 (base and 768/v). +* By acquiring the output of VAE in advance and saving it to disk, we aim to save memory and speed up learning. + +Text Encoder is not trained by default. For fine tuning of the whole model, it seems common to learn only U-Net (NovelAI seems to be the same). Text Encoder can also be learned as an option. + +## Additional features +### Change CLIP output +CLIP (Text Encoder) converts the text into features in order to reflect the prompt in the image. Stable diffusion uses the output of the last layer of CLIP, but you can change it to use the output of the penultimate layer. According to NovelAI, this will reflect prompts more accurately. +It is also possible to use the output of the last layer as is. +*Stable Diffusion 2.0 uses the penultimate layer by default. Do not specify the clip_skip option. + +### Training in non-square resolutions +Stable Diffusion is trained at 512\*512, but also at resolutions such as 256\*1024 and 384\*640. It is expected that this will reduce the cropped portion and learn the relationship between prompts and images more correctly. +The learning resolution is adjusted vertically and horizontally in units of 64 pixels within a range that does not exceed the resolution area (= memory usage) given as a parameter. + +In machine learning, it is common to unify all input sizes, but there are no particular restrictions, and in fact it is okay as long as they are unified within the same batch. NovelAI's bucketing seems to refer to classifying training data in advance for each learning resolution according to the aspect ratio. And by creating a batch with the images in each bucket, the image size of the batch is unified. + +### Extending token length from 75 to 225 +Stable diffusion has a maximum of 75 tokens (77 tokens including the start and end), but we will extend it to 225 tokens. +However, the maximum length that CLIP accepts is 75 tokens, so in the case of 225 tokens, we simply divide it into thirds, call CLIP, and then concatenate the results. + +*I'm not sure if this is the preferred implementation. It seems to be working for now. Especially in 2.0, there is no implementation that can be used as a reference, so I have implemented it independently. + +*Automatic1111's Web UI seems to divide the text with commas in mind, but in my case, it's a simple division. + +## Environmental arrangement + +See the [README](./README-en.md) in this repository. + +## Preparing teacher data + +Prepare the image data you want to learn and put it in any folder. No prior preparation such as resizing is required. +However, for images that are smaller than the training resolution, it is recommended to enlarge them while maintaining the quality using super-resolution. + +It also supports multiple teacher data folders. Preprocessing will be executed for each folder. + +For example, store an image like this: + +![Teacher data folder screenshot](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png) + +## Automatic captioning +Skip if you just want to learn tags without captions. + +Also, when preparing captions manually, prepare them in the same directory as the teacher data image, with the same file name, extension .caption, etc. Each file should be a text file with only one line. + +### Captioning with BLIP + +The latest version no longer requires BLIP downloads, weight downloads, and additional virtual environments. Works as-is. + +Run make_captions.py in the finetune folder. + +``` +python finetune\make_captions.py --batch_size +``` + +If the batch size is 8 and the training data is placed in the parent folder train_data, it will be as follows. + +``` +python finetune\make_captions.py --batch_size 8 ..\train_data +``` + +A caption file is created in the same directory as the teacher data image with the same file name and extension .caption. + +Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). +You can specify the maximum length of the caption with the max_length option. Default is 75. It may be longer if the model is trained with a token length of 225. +You can change the caption extension with the caption_extension option. Default is .caption (.txt conflicts with DeepDanbooru described later). + +If there are multiple teacher data folders, execute for each folder. + +Note that the inference is random, so the results will change each time you run it. If you want to fix it, specify a random number seed like "--seed 42" with the --seed option. + +For other options, please refer to the help with --help (there seems to be no documentation for the meaning of the parameters, so you have to look at the source). + +A caption file is generated with the extension .caption by default. + +![Folder where caption is generated](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) + +For example, with captions like: + +![captions and images](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) + +## Tagged by DeepDanbooru +If you do not want to tag the danbooru tag itself, please proceed to "Preprocessing of caption and tag information". + +Tagging is done with DeepDanbooru or WD14Tagger. WD14Tagger seems to be more accurate. If you want to tag with WD14Tagger, skip to the next chapter. + +### Environmental arrangement +Clone DeepDanbooru https://github.com/KichangKim/DeepDanbooru into your working folder, or download the zip and extract it. I unzipped it. +Also, download deepdanbooru-v3-20211112-sgd-e28.zip from Assets of "DeepDanbooru Pretrained Model v3-20211112-sgd-e28" on the DeepDanbooru Releases page https://github.com/KichangKim/DeepDanbooru/releases and extract it to the DeepDanbooru folder. + +Download from below. Click to open Assets and download from there. + +![DeepDanbooru download page](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) + +Make a directory structure like this + +![DeepDanbooru directory structure](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) + +Install the necessary libraries for the Diffusers environment. Go to the DeepDanbooru folder and install it (I think it's actually just adding tensorflow-io). + +``` +pip install -r requirements.txt +``` + +Next, install DeepDanbooru itself. + +``` +pip install . +``` + +This completes the preparation of the environment for tagging. + +### Implementing tagging +Go to DeepDanbooru's folder and run deepdanbooru to tag. + +``` +deepdanbooru evaluate --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +If you put the training data in the parent folder train_data, it will be as follows. + +``` +deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. It is slow because it is processed one by one. + +If there are multiple teacher data folders, execute for each folder. + +It is generated as follows. + +![DeepDanbooru generated files](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) + +A tag is attached like this (great amount of information...). + +![Deep Danbooru tag and image](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) + +## Tagging with WD14Tagger +This procedure uses WD14Tagger instead of DeepDanbooru. + +Use the tagger used in Mr. Automatic1111's WebUI. I referred to the information on this github page (https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger). + +The modules required for the initial environment maintenance have already been installed. Weights are automatically downloaded from Hugging Face. + +### Implementing tagging +Run the script to do the tagging. +``` +python tag_images_by_wd14_tagger.py --batch_size +``` + +If you put the training data in the parent folder train_data, it will be as follows. +``` +python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data +``` + +The model file will be automatically downloaded to the wd14_tagger_model folder on first launch (folder can be changed in options). It will be as follows. + +![downloaded file](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) + +A tag file is created in the same directory as the teacher data image with the same file name and extension .txt. + +![generated tag file](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) + +![tags and images](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) + +With the thresh option, you can specify the number of confidences of the determined tag to attach the tag. The default is 0.35, same as the WD14Tagger sample. Lower values give more tags, but less accuracy. +Increase or decrease batch_size according to the VRAM capacity of the GPU. Bigger is faster (I think 12GB of VRAM can be a little more). You can change the tag file extension with the caption_extension option. Default is .txt. +You can specify the folder where the model is saved with the model_dir option. +Also, if you specify the force_download option, the model will be re-downloaded even if there is a save destination folder. + +If there are multiple teacher data folders, execute for each folder. + +## Preprocessing caption and tag information + +Combine captions and tags into a single file as metadata for easy processing from scripts. + +### Caption preprocessing + +To put captions into the metadata, run the following in your working folder (if you don't use captions for learning, you don't need to run this) (it's actually a single line, and so on). + +``` +python merge_captions_to_metadata.py +--in_json + +``` + +The metadata file name is an arbitrary name. +If the training data is train_data, there is no metadata file to read, and the metadata file is meta_cap.json, it will be as follows. + +``` +python merge_captions_to_metadata.py train_data meta_cap.json +``` + +You can specify the caption extension with the caption_extension option. + +If there are multiple teacher data folders, please specify the full_path argument (metadata will have full path information). Then run it for each folder. + +``` +python merge_captions_to_metadata.py --full_path + train_data1 meta_cap1.json +python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json + train_data2 meta_cap2.json +``` + +If in_json is omitted, if there is a write destination metadata file, it will be read from there and overwritten there. + +__*It is safe to rewrite the in_json option and the write destination each time and write to a separate metadata file. __ + +### Tag preprocessing + +Similarly, tags are also collected in metadata (no need to do this if tags are not used for learning). +``` +python merge_dd_tags_to_metadata.py + --in_json + +``` + +With the same directory structure as above, when reading meta_cap.json and writing to meta_cap_dd.json, it will be as follows. +``` +python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json +``` + +If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder. + +``` +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json + train_data1 meta_cap_dd1.json +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json + train_data2 meta_cap_dd2.json +``` + +If in_json is omitted, if there is a write destination metadata file, it will be read from there and overwritten there. + +__*It is safe to rewrite the in_json option and the write destination each time and write to a separate metadata file. __ + +### Cleaning captions and tags +Up to this point, captions and DeepDanbooru tags have been put together in the metadata file. However, captions with automatic captioning are subtle due to spelling variations (*), and tags include underscores and ratings (in the case of DeepDanbooru), so the editor's replacement function etc. You should use it to clean your captions and tags. + +*For example, when learning a girl in an anime picture, there are variations in captions such as girl/girls/woman/women. Also, it may be more appropriate to simply use "girl" for things like "anime girl". + +A script for cleaning is provided, so please edit the contents of the script according to the situation and use it. + +(It is no longer necessary to specify the teacher data folder. All data in the metadata will be cleaned.) + +``` +python clean_captions_and_tags.py +``` + +Please note that --in_json is not included. For example: + +``` +python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json +``` + +Preprocessing of captions and tags is now complete. + +## Get latents in advance + +In order to speed up the learning, we acquire the latent representation of the image in advance and save it to disk. At the same time, bucketing (classifying the training data according to the aspect ratio) is performed. + +In your working folder, type: +``` +python prepare_buckets_latents.py + + + --batch_size + --max_resolution + --mixed_precision +``` + +If the model is model.ckpt, batch size 4, training resolution is 512\*512, precision is no (float32), read metadata from meta_clean.json and write to meta_lat.json: + +``` +python prepare_buckets_latents.py + train_data meta_clean.json meta_lat.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no +``` + +Latents are saved in numpy npz format in the teacher data folder. + +Specify the --v2 option when loading a Stable Diffusion 2.0 model (--v_parameterization is not required). + +You can specify the minimum resolution size with the --min_bucket_reso option and the maximum size with the --max_bucket_reso option. The defaults are 256 and 1024 respectively. For example, specifying a minimum size of 384 will not use resolutions such as 256\*1024 or 320\*768. +If you increase the resolution to something like 768\*768, you should specify something like 1280 for the maximum size. + +If you specify the --flip_aug option, it will perform horizontal flip augmentation (data augmentation). You can artificially double the amount of data, but if you specify it when the data is not left-right symmetrical (for example, character appearance, hairstyle, etc.), learning will not go well. +(This is a simple implementation that acquires the latents for the flipped image and saves the \*\_flip.npz file. No options are required for fline_tune.py. If there is a file with \_flip, Randomly load a file without + +The batch size may be increased a little more even with 12GB of VRAM. +The resolution is a number divisible by 64, and is specified by "width, height". The resolution is directly linked to the memory size during fine tuning. 512,512 seems to be the limit with VRAM 12GB (*). 16GB may be raised to 512,704 or 512,768. Even with 256, 256, etc., it seems to be difficult with 8GB of VRAM (because parameters and optimizers require a certain amount of memory regardless of resolution). + +*There was also a report that learning batch size 1 worked with 12GB VRAM and 640,640. + +The result of bucketing is displayed as follows. + +![bucketing result](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png) + +If you have multiple teacher data folders, please specify the full_path argument. Then run it for each folder. +``` +python prepare_buckets_latents.py --full_path + train_data1 meta_clean.json meta_lat1.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +python prepare_buckets_latents.py --full_path + train_data2 meta_lat1.json meta_lat2.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +``` +It is possible to make the read source and write destination the same, but separate is safer. + +__*It is safe to rewrite the argument each time and write it to a separate metadata file. __ + + +## Run training +For example: Below are the settings for saving memory. +``` +accelerate launch --num_cpu_threads_per_process 8 fine_tune.py + --pretrained_model_name_or_path=model.ckpt + --in_json meta_lat.json + --train_data_dir=train_data + --output_dir=fine_tuned + --shuffle_caption + --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 + --use_8bit_adam --xformers --gradient_checkpointing + --mixed_precision=bf16 + --save_every_n_epochs=4 +``` + +It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process of accelerate. + +Specify the model to be trained in pretrained_model_name_or_path (Stable Diffusion checkpoint or Diffusers model). Stable Diffusion checkpoint supports .ckpt and .safetensors (automatically determined by extension). + +Specifies the metadata file when caching latent to in_json. + +Specify the training data folder for train_data_dir and the output destination folder for the trained model for output_dir. + +If shuffle_caption is specified, captions and tags are shuffled and learned in units separated by commas (this is the method used in Waifu Diffusion v1.3). +(You can keep some of the leading tokens fixed without shuffling. See keep_tokens for other options.) + +Specify the batch size in train_batch_size. Specify 1 or 2 for VRAM 12GB. The number that can be specified also changes depending on the resolution. +The actual amount of data used for training is "batch size x number of steps". When increasing the batch size, the number of steps can be decreased accordingly. + +Specify the learning rate in learning_rate. For example Waifu Diffusion v1.3 seems to be 5e-6. +Specify the number of steps in max_train_steps. + +Specify use_8bit_adam to use the 8-bit Adam Optimizer. It saves memory and speeds up, but accuracy may decrease. + +Specifying xformers replaces CrossAttention to save memory and speed up. +* As of 11/9, xformers will cause an error in float32 learning, so please use bf16/fp16 or use memory-saving CrossAttention with mem_eff_attn instead (speed is inferior to xformers). + +Enable intermediate saving of gradients in gradient_checkpointing. It's slower, but uses less memory. + +Specifies whether to use mixed precision with mixed_precision. Specifying "fp16" or "bf16" saves memory, but accuracy is inferior. +"fp16" and "bf16" use almost the same amount of memory, and it is said that bf16 has better learning results (I didn't feel much difference in the range I tried). +If "no" is specified, it will not be used (it will be float32). + +* It seems that an error will occur when reading checkpoints learned with bf16 with Mr. AUTOMATIC1111's Web UI. This seems to be because the data type bfloat16 causes an error in the Web UI model safety checker. Save in fp16 or float32 format with the save_precision option. Or it seems to be good to store it in safetytensors format. + +Specifying save_every_n_epochs will save the model being trained every time that many epochs have passed. + +### Supports Stable Diffusion 2.0 +Specify the --v2 option when using Hugging Face's stable-diffusion-2-base, and specify both --v2 and --v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt please. + +### Increase accuracy and speed when memory is available +First, removing gradient_checkpointing will speed it up. However, the batch size that can be set is reduced, so please set while looking at the balance between accuracy and speed. + +Increasing the batch size increases speed and accuracy. Increase the speed while checking the speed per data within the range where the memory is sufficient (the speed may actually decrease when the memory is at the limit). + +### Change CLIP output used +Specifying 2 for the clip_skip option uses the output of the next-to-last layer. If 1 or option is omitted, the last layer is used. +The learned model should be able to be inferred by Automatic1111's web UI. + +*SD2.0 uses the second layer from the back by default, so please do not specify it when learning SD2.0. + +If the model being trained was originally trained to use the second layer, 2 is a good value. + +If you were using the last layer instead, the entire model would have been trained on that assumption. Therefore, if you train again using the second layer, you may need a certain number of teacher data and longer learning to obtain the desired learning result. + +### Extending Token Length +You can learn by extending the token length by specifying 150 or 225 for max_token_length. +The learned model should be able to be inferred by Automatic1111's web UI. + +As with clip_skip, learning with a length different from the learning state of the model may require a certain amount of teacher data and a longer learning time. + +### Save learning log +Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved. + +For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder. +Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=fine_tune_style1" for identification. + +To check the log with TensorBoard, open another command prompt and enter the following in the working folder (I think tensorboard is installed when Diffusers is installed, but if it is not installed, pip install Please put it in tensorboard). +``` +tensorboard --logdir=logs +``` + +### Learning Hypernetworks +It will be explained in another article. + +### Learning with fp16 gradient (experimental feature) +The full_fp16 option will change the gradient from normal float32 to float16 (fp16) and learn (it seems to be full fp16 learning instead of mixed precision). As a result, it seems that the SD1.x 512*512 size can be learned with a VRAM usage of less than 8GB, and the SD2.x 512*512 size can be learned with a VRAM usage of less than 12GB. + +Specify fp16 in advance in accelerate config and optionally set mixed_precision="fp16" (does not work with bf16). + +To minimize memory usage, use the xformers, use_8bit_adam, gradient_checkpointing options and set train_batch_size to 1. +(If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.) + +It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). The accuracy will drop considerably, and the probability of learning failure on the way will also increase. The setting of the learning rate and the number of steps seems to be severe. Please be aware of them and use them at your own risk. + +### Other Options + +#### keep_tokens +If a number is specified, the specified number of tokens (comma-separated strings) from the beginning of the caption are fixed without being shuffled. + +If there are both captions and tags, the prompts during learning will be concatenated like "caption, tag 1, tag 2...", so if you set "--keep_tokens=1", the caption will always be at the beginning during learning. will come. + +#### dataset_repeats +If the number of data sets is extremely small, the epoch will end soon (it will take some time at the epoch break), so please specify a numerical value and multiply the data by some to make the epoch longer. + +#### train_text_encoder +Text Encoder is also a learning target. Slightly increased memory usage. + +In normal fine tuning, the Text Encoder is not targeted for training (probably because U-Net is trained to follow the output of the Text Encoder), but if the number of training data is small, the Text Encoder is trained like DreamBooth. also seems to be valid. + +#### save_precision +The data format when saving checkpoints can be specified from float, fp16, and bf16 (if not specified, it is the same as the data format during learning). It saves disk space, but the model produces different results. Also, if you specify float or fp16, you should be able to read it on Mr. 1111's Web UI. + +*For VAE, the data format of the original checkpoint will remain, so the model size may not be reduced to a little over 2GB even with fp16. + +#### save_model_as +Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors. + +When reading Stable Diffusion format (ckpt or safetensors) and saving in Diffusers format, missing information is supplemented by dropping v1.5 or v2.1 information from Hugging Face. + +#### use_safetensors +This option saves checkpoints in safetyensors format. The save format will be the default (same format as loaded). + +#### save_state and resume +The save_state option saves the learning state of the optimizer, etc. in addition to the checkpoint in the folder when saving midway and at the final save. This avoids a decrease in accuracy when learning is resumed after being interrupted (since the optimizer optimizes while having a state, if the state is reset, the optimization must be performed again from the initial state. not). Note that the number of steps is not saved due to Accelerate specifications. + +When starting the script, you can resume by specifying the folder where the state is saved with the resume option. + +Please note that the learning state will be about 5 GB per save, so please be careful of the disk capacity. + +#### gradient_accumulation_steps +Updates the gradient in batches for the specified number of steps. Has a similar effect to increasing the batch size, but consumes slightly more memory. + +*The Accelerate specification does not support multiple learning models, so if you set Text Encoder as the learning target and specify a value of 2 or more for this option, an error may occur. + +#### lr_scheduler / lr_warmup_steps +You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. Default is constant. + +With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details. + +#### diffusers_xformers +Uses Diffusers' xformers feature rather than the script's own xformers replacement feature. Hypernetwork learning is no longer possible. \ No newline at end of file diff --git a/fine_tune_README_ja.md b/fine_tune_README_ja.md new file mode 100644 index 0000000..f763490 --- /dev/null +++ b/fine_tune_README_ja.md @@ -0,0 +1,465 @@ +NovelAIの提案した学習手法、自動キャプションニング、タグ付け、Windows+VRAM 12GB(v1.4/1.5の場合)環境等に対応したfine tuningです。 + +## 概要 +Diffusersを用いてStable DiffusionのU-Netのfine tuningを行います。NovelAIの記事にある以下の改善に対応しています(Aspect Ratio BucketingについてはNovelAIのコードを参考にしましたが、最終的なコードはすべてオリジナルです)。 + +* CLIP(Text Encoder)の最後の層ではなく最後から二番目の層の出力を用いる。 +* 正方形以外の解像度での学習(Aspect Ratio Bucketing) 。 +* トークン長を75から225に拡張する。 +* BLIPによるキャプショニング(キャプションの自動作成)、DeepDanbooruまたはWD14Taggerによる自動タグ付けを行う。 +* Hypernetworkの学習にも対応する。 +* Stable Diffusion v2.0(baseおよび768/v)に対応。 +* VAEの出力をあらかじめ取得しディスクに保存しておくことで、学習の省メモリ化、高速化を図る。 + +デフォルトではText Encoderの学習は行いません。モデル全体のfine tuningではU-Netだけを学習するのが一般的なようです(NovelAIもそのようです)。オプション指定でText Encoderも学習対象とできます。 + +## 追加機能について +### CLIPの出力の変更 +プロンプトを画像に反映するため、テキストの特徴量への変換を行うのがCLIP(Text Encoder)です。Stable DiffusionではCLIPの最後の層の出力を用いていますが、それを最後から二番目の層の出力を用いるよう変更できます。NovelAIによると、これによりより正確にプロンプトが反映されるようになるとのことです。 +元のまま、最後の層の出力を用いることも可能です。 +※Stable Diffusion 2.0では最後から二番目の層をデフォルトで使います。clip_skipオプションを指定しないでください。 + +### 正方形以外の解像度での学習 +Stable Diffusionは512\*512で学習されていますが、それに加えて256\*1024や384\*640といった解像度でも学習します。これによりトリミングされる部分が減り、より正しくプロンプトと画像の関係が学習されることが期待されます。 +学習解像度はパラメータとして与えられた解像度の面積(=メモリ使用量)を超えない範囲で、64ピクセル単位で縦横に調整、作成されます。 + +機械学習では入力サイズをすべて統一するのが一般的ですが、特に制約があるわけではなく、実際は同一のバッチ内で統一されていれば大丈夫です。NovelAIの言うbucketingは、あらかじめ教師データを、アスペクト比に応じた学習解像度ごとに分類しておくことを指しているようです。そしてバッチを各bucket内の画像で作成することで、バッチの画像サイズを統一します。 + +### トークン長の75から225への拡張 +Stable Diffusionでは最大75トークン(開始・終了を含むと77トークン)ですが、それを225トークンまで拡張します。 +ただしCLIPが受け付ける最大長は75トークンですので、225トークンの場合、単純に三分割してCLIPを呼び出してから結果を連結しています。 + +※これが望ましい実装なのかどうかはいまひとつわかりません。とりあえず動いてはいるようです。特に2.0では何も参考になる実装がないので独自に実装してあります。 + +※Automatic1111氏のWeb UIではカンマを意識して分割、といったこともしているようですが、私の場合はそこまでしておらず単純な分割です。 + +## 環境整備 + +このリポジトリの[README](./README-ja.md)を参照してください。 + +## 教師データの用意 + +学習させたい画像データを用意し、任意のフォルダに入れてください。リサイズ等の事前の準備は必要ありません。 +ただし学習解像度よりもサイズが小さい画像については、超解像などで品質を保ったまま拡大しておくことをお勧めします。 + +複数の教師データフォルダにも対応しています。前処理をそれぞれのフォルダに対して実行する形となります。 + +たとえば以下のように画像を格納します。 + +![教師データフォルダのスクショ](https://user-images.githubusercontent.com/52813779/208907739-8e89d5fa-6ca8-4b60-8927-f484d2a9ae04.png) + +## 自動キャプショニング +キャプションを使わずタグだけで学習する場合はスキップしてください。 + +また手動でキャプションを用意する場合、キャプションは教師データ画像と同じディレクトリに、同じファイル名、拡張子.caption等で用意してください。各ファイルは1行のみのテキストファイルとします。 + +### BLIPによるキャプショニング + +最新版ではBLIPのダウンロード、重みのダウンロード、仮想環境の追加は不要になりました。そのままで動作します。 + +finetuneフォルダ内のmake_captions.pyを実行します。 + +``` +python finetune\make_captions.py --batch_size <バッチサイズ> <教師データフォルダ> +``` + +バッチサイズ8、教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 + +``` +python finetune\make_captions.py --batch_size 8 ..\train_data +``` + +キャプションファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.captionで作成されます。 + +batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。 +max_lengthオプションでキャプションの最大長を指定できます。デフォルトは75です。モデルをトークン長225で学習する場合には長くしても良いかもしれません。 +caption_extensionオプションでキャプションの拡張子を変更できます。デフォルトは.captionです(.txtにすると後述のDeepDanbooruと競合します)。 + +複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 + +なお、推論にランダム性があるため、実行するたびに結果が変わります。固定する場合には--seedオプションで「--seed 42」のように乱数seedを指定してください。 + +その他のオプションは--helpでヘルプをご参照ください(パラメータの意味についてはドキュメントがまとまっていないようで、ソースを見るしかないようです)。 + +デフォルトでは拡張子.captionでキャプションファイルが生成されます。 + +![captionが生成されたフォルダ](https://user-images.githubusercontent.com/52813779/208908845-48a9d36c-f6ee-4dae-af71-9ab462d1459e.png) + +たとえば以下のようなキャプションが付きます。 + +![キャプションと画像](https://user-images.githubusercontent.com/52813779/208908947-af936957-5d73-4339-b6c8-945a52857373.png) + +## DeepDanbooruによるタグ付け +danbooruタグのタグ付け自体を行わない場合は「キャプションとタグ情報の前処理」に進んでください。 + +タグ付けはDeepDanbooruまたはWD14Taggerで行います。WD14Taggerのほうが精度が良いようです。WD14Taggerでタグ付けする場合は、次の章へ進んでください。 + +### 環境整備 +DeepDanbooru https://github.com/KichangKim/DeepDanbooru を作業フォルダにcloneしてくるか、zipをダウンロードして展開します。私はzipで展開しました。 +またDeepDanbooruのReleasesのページ https://github.com/KichangKim/DeepDanbooru/releases の「DeepDanbooru Pretrained Model v3-20211112-sgd-e28」のAssetsから、deepdanbooru-v3-20211112-sgd-e28.zipをダウンロードしてきてDeepDanbooruのフォルダに展開します。 + +以下からダウンロードします。Assetsをクリックして開き、そこからダウンロードします。 + +![DeepDanbooruダウンロードページ](https://user-images.githubusercontent.com/52813779/208909417-10e597df-7085-41ee-bd06-3e856a1339df.png) + +以下のようなこういうディレクトリ構造にしてください + +![DeepDanbooruのディレクトリ構造](https://user-images.githubusercontent.com/52813779/208909486-38935d8b-8dc6-43f1-84d3-fef99bc471aa.png) + +Diffusersの環境に必要なライブラリをインストールします。DeepDanbooruのフォルダに移動してインストールします(実質的にはtensorflow-ioが追加されるだけだと思います)。 + +``` +pip install -r requirements.txt +``` + +続いてDeepDanbooru自体をインストールします。 + +``` +pip install . +``` + +以上でタグ付けの環境整備は完了です。 + +### タグ付けの実施 +DeepDanbooruのフォルダに移動し、deepdanbooruを実行してタグ付けを行います。 + +``` +deepdanbooru evaluate <教師データフォルダ> --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 + +``` +deepdanbooru evaluate ../train_data --project-path deepdanbooru-v3-20211112-sgd-e28 --allow-folder --save-txt +``` + +タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。1件ずつ処理されるためわりと遅いです。 + +複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 + +以下のように生成されます。 + +![DeepDanbooruの生成ファイル](https://user-images.githubusercontent.com/52813779/208909855-d21b9c98-f2d3-4283-8238-5b0e5aad6691.png) + +こんな感じにタグが付きます(すごい情報量……)。 + +![DeepDanbooruタグと画像](https://user-images.githubusercontent.com/52813779/208909908-a7920174-266e-48d5-aaef-940aba709519.png) + +## WD14Taggerによるタグ付け +DeepDanbooruの代わりにWD14Taggerを用いる手順です。 + +Automatic1111氏のWebUIで使用しているtaggerを利用します。こちらのgithubページ(https://github.com/toriato/stable-diffusion-webui-wd14-tagger#mrsmilingwolfs-model-aka-waifu-diffusion-14-tagger )の情報を参考にさせていただきました。 + +最初の環境整備で必要なモジュールはインストール済みです。また重みはHugging Faceから自動的にダウンロードしてきます。 + +### タグ付けの実施 +スクリプトを実行してタグ付けを行います。 +``` +python tag_images_by_wd14_tagger.py --batch_size <バッチサイズ> <教師データフォルダ> +``` + +教師データを親フォルダのtrain_dataに置いた場合、以下のようになります。 +``` +python tag_images_by_wd14_tagger.py --batch_size 4 ..\train_data +``` + +初回起動時にはモデルファイルがwd14_tagger_modelフォルダに自動的にダウンロードされます(フォルダはオプションで変えられます)。以下のようになります。 + +![ダウンロードされたファイル](https://user-images.githubusercontent.com/52813779/208910447-f7eb0582-90d6-49d3-a666-2b508c7d1842.png) + +タグファイルが教師データ画像と同じディレクトリに、同じファイル名、拡張子.txtで作成されます。 + +![生成されたタグファイル](https://user-images.githubusercontent.com/52813779/208910534-ea514373-1185-4b7d-9ae3-61eb50bc294e.png) + +![タグと画像](https://user-images.githubusercontent.com/52813779/208910599-29070c15-7639-474f-b3e4-06bd5a3df29e.png) + +threshオプションで、判定されたタグのconfidence(確信度)がいくつ以上でタグをつけるかが指定できます。デフォルトはWD14Taggerのサンプルと同じ0.35です。値を下げるとより多くのタグが付与されますが、精度は下がります。 +batch_sizeはGPUのVRAM容量に応じて増減してください。大きいほうが速くなります(VRAM 12GBでももう少し増やせると思います)。caption_extensionオプションでタグファイルの拡張子を変更できます。デフォルトは.txtです。 +model_dirオプションでモデルの保存先フォルダを指定できます。 +またforce_downloadオプションを指定すると保存先フォルダがあってもモデルを再ダウンロードします。 + +複数の教師データフォルダがある場合には、それぞれのフォルダに対して実行してください。 + +## キャプションとタグ情報の前処理 + +スクリプトから処理しやすいようにキャプションとタグをメタデータとしてひとつのファイルにまとめます。 + +### キャプションの前処理 + +キャプションをメタデータに入れるには、作業フォルダ内で以下を実行してください(キャプションを学習に使わない場合は実行不要です)(実際は1行で記述します、以下同様)。 + +``` +python merge_captions_to_metadata.py <教師データフォルダ> +  --in_json <読み込むメタデータファイル名> + <メタデータファイル名> +``` + +メタデータファイル名は任意の名前です。 +教師データがtrain_data、読み込むメタデータファイルなし、メタデータファイルがmeta_cap.jsonの場合、以下のようになります。 + +``` +python merge_captions_to_metadata.py train_data meta_cap.json +``` + +caption_extensionオプションでキャプションの拡張子を指定できます。 + +複数の教師データフォルダがある場合には、full_path引数を指定してください(メタデータにフルパスで情報を持つようになります)。そして、それぞれのフォルダに対して実行してください。 + +``` +python merge_captions_to_metadata.py --full_path + train_data1 meta_cap1.json +python merge_captions_to_metadata.py --full_path --in_json meta_cap1.json + train_data2 meta_cap2.json +``` + +in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。 + +__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__ + +### タグの前処理 + +同様にタグもメタデータにまとめます(タグを学習に使わない場合は実行不要です)。 +``` +python merge_dd_tags_to_metadata.py <教師データフォルダ> + --in_json <読み込むメタデータファイル名> + <書き込むメタデータファイル名> +``` + +先と同じディレクトリ構成で、meta_cap.jsonを読み、meta_cap_dd.jsonに書きだす場合、以下となります。 +``` +python merge_dd_tags_to_metadata.py train_data --in_json meta_cap.json meta_cap_dd.json +``` + +複数の教師データフォルダがある場合には、full_path引数を指定してください。そして、それぞれのフォルダに対して実行してください。 + +``` +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap2.json + train_data1 meta_cap_dd1.json +python merge_dd_tags_to_metadata.py --full_path --in_json meta_cap_dd1.json + train_data2 meta_cap_dd2.json +``` + +in_jsonを省略すると書き込み先メタデータファイルがあるとそこから読み込み、そこに上書きします。 + +__※in_jsonオプションと書き込み先を都度書き換えて、別のメタデータファイルへ書き出すようにすると安全です。__ + +### キャプションとタグのクリーニング +ここまででメタデータファイルにキャプションとDeepDanbooruのタグがまとめられています。ただ自動キャプショニングにしたキャプションは表記ゆれなどがあり微妙(※)ですし、タグにはアンダースコアが含まれていたりratingが付いていたりしますので(DeepDanbooruの場合)、エディタの置換機能などを用いてキャプションとタグのクリーニングをしたほうがいいでしょう。 + +※たとえばアニメ絵の少女を学習する場合、キャプションにはgirl/girls/woman/womenなどのばらつきがあります。また「anime girl」なども単に「girl」としたほうが適切かもしれません。 + +クリーニング用のスクリプトが用意してありますので、スクリプトの内容を状況に応じて編集してお使いください。 + +(教師データフォルダの指定は不要になりました。メタデータ内の全データをクリーニングします。) + +``` +python clean_captions_and_tags.py <読み込むメタデータファイル名> <書き込むメタデータファイル名> +``` + +--in_jsonは付きませんのでご注意ください。たとえば次のようになります。 + +``` +python clean_captions_and_tags.py meta_cap_dd.json meta_clean.json +``` + +以上でキャプションとタグの前処理は完了です。 + +## latentsの事前取得 + +学習を高速に進めるためあらかじめ画像の潜在表現を取得しディスクに保存しておきます。あわせてbucketing(教師データをアスペクト比に応じて分類する)を行います。 + +作業フォルダで以下のように入力してください。 +``` +python prepare_buckets_latents.py <教師データフォルダ> + <読み込むメタデータファイル名> <書き込むメタデータファイル名> + + --batch_size <バッチサイズ> + --max_resolution <解像度 幅,高さ> + --mixed_precision <精度> +``` + +モデルがmodel.ckpt、バッチサイズ4、学習解像度は512\*512、精度no(float32)で、meta_clean.jsonからメタデータを読み込み、meta_lat.jsonに書き込む場合、以下のようになります。 + +``` +python prepare_buckets_latents.py + train_data meta_clean.json meta_lat.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no +``` + +教師データフォルダにnumpyのnpz形式でlatentsが保存されます。 + +Stable Diffusion 2.0のモデルを読み込む場合は--v2オプションを指定してください(--v_parameterizationは不要です)。 + +解像度の最小サイズを--min_bucket_resoオプションで、最大サイズを--max_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。たとえば最小サイズに384を指定すると、256\*1024や320\*768などの解像度は使わなくなります。 +解像度を768\*768のように大きくした場合、最大サイズに1280などを指定すると良いでしょう。 + +--flip_augオプションを指定すると左右反転のaugmentation(データ拡張)を行います。疑似的にデータ量を二倍に増やすことができますが、データが左右対称でない場合に指定すると(例えばキャラクタの外見、髪型など)学習がうまく行かなくなります。 +(反転した画像についてもlatentsを取得し、\*\_flip.npzファイルを保存する単純な実装です。fline_tune.pyには特にオプション指定は必要ありません。\_flip付きのファイルがある場合、flip付き・なしのファイルを、ランダムに読み込みます。) + +バッチサイズはVRAM 12GBでももう少し増やせるかもしれません。 +解像度は64で割り切れる数字で、"幅,高さ"で指定します。解像度はfine tuning時のメモリサイズに直結します。VRAM 12GBでは512,512が限界と思われます(※)。16GBなら512,704や512,768まで上げられるかもしれません。なお256,256等にしてもVRAM 8GBでは厳しいようです(パラメータやoptimizerなどは解像度に関係せず一定のメモリが必要なため)。 + +※batch size 1の学習で12GB VRAM、640,640で動いたとの報告もありました。 + +以下のようにbucketingの結果が表示されます。 + +![bucketingの結果](https://user-images.githubusercontent.com/52813779/208911419-71c00fbb-2ce6-49d5-89b5-b78d7715e441.png) + +複数の教師データフォルダがある場合には、full_path引数を指定してください。そして、それぞれのフォルダに対して実行してください。 +``` +python prepare_buckets_latents.py --full_path + train_data1 meta_clean.json meta_lat1.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +python prepare_buckets_latents.py --full_path + train_data2 meta_lat1.json meta_lat2.json model.ckpt + --batch_size 4 --max_resolution 512,512 --mixed_precision no + +``` +読み込み元と書き込み先を同じにすることも可能ですが別々の方が安全です。 + +__※引数を都度書き換えて、別のメタデータファイルに書き込むと安全です。__ + + +## 学習の実行 +たとえば以下のように実行します。以下は省メモリ化のための設定です。 +``` +accelerate launch --num_cpu_threads_per_process 8 fine_tune.py + --pretrained_model_name_or_path=model.ckpt + --in_json meta_lat.json + --train_data_dir=train_data + --output_dir=fine_tuned + --shuffle_caption + --train_batch_size=1 --learning_rate=5e-6 --max_train_steps=10000 + --use_8bit_adam --xformers --gradient_checkpointing + --mixed_precision=bf16 + --save_every_n_epochs=4 +``` + +accelerateのnum_cpu_threads_per_processにはCPUのコア数を指定するとよいようです。 + +pretrained_model_name_or_pathに学習対象のモデルを指定します(Stable DiffusionのcheckpointかDiffusersのモデル)。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています(拡張子で自動判定)。 + +in_jsonにlatentをキャッシュしたときのメタデータファイルを指定します。 + +train_data_dirに教師データのフォルダを、output_dirに学習後のモデルの出力先フォルダを指定します。 + +shuffle_captionを指定すると、キャプション、タグをカンマ区切りされた単位でシャッフルして学習します(Waifu Diffusion v1.3で行っている手法です)。 +(先頭のトークンのいくつかをシャッフルせずに固定できます。その他のオプションのkeep_tokensをご覧ください。) + +train_batch_sizeにバッチサイズを指定します。VRAM 12GBでは1か2程度を指定してください。解像度によっても指定可能な数は変わってきます。 +学習に使用される実際のデータ量は「バッチサイズ×ステップ数」です。バッチサイズを増やした時には、それに応じてステップ数を下げることが可能です。 + +learning_rateに学習率を指定します。たとえばWaifu Diffusion v1.3は5e-6のようです。 +max_train_stepsにステップ数を指定します。 + +use_8bit_adamを指定すると8-bit Adam Optimizerを使用します。省メモリ化、高速化されますが精度は下がる可能性があります。 + +xformersを指定するとCrossAttentionを置換して省メモリ化、高速化します。 +※11/9時点ではfloat32の学習ではxformersがエラーになるため、bf16/fp16を使うか、代わりにmem_eff_attnを指定して省メモリ版CrossAttentionを使ってください(速度はxformersに劣ります)。 + +gradient_checkpointingで勾配の途中保存を有効にします。速度は遅くなりますが使用メモリ量が減ります。 + +mixed_precisionで混合精度を使うか否かを指定します。"fp16"または"bf16"を指定すると省メモリになりますが精度は劣ります。 +"fp16"と"bf16"は使用メモリ量はほぼ同じで、bf16の方が学習結果は良くなるとの話もあります(試した範囲ではあまり違いは感じられませんでした)。 +"no"を指定すると使用しません(float32になります)。 + +※bf16で学習したcheckpointをAUTOMATIC1111氏のWeb UIで読み込むとエラーになるようです。これはデータ型のbfloat16がWeb UIのモデルsafety checkerでエラーとなるためのようです。save_precisionオプションを指定してfp16またはfloat32形式で保存してください。またはsafetensors形式で保管しても良さそうです。 + +save_every_n_epochsを指定するとそのエポックだけ経過するたびに学習中のモデルを保存します。 + +### Stable Diffusion 2.0対応 +Hugging Faceのstable-diffusion-2-baseを使う場合は--v2オプションを、stable-diffusion-2または768-v-ema.ckptを使う場合は--v2と--v_parameterizationの両方のオプションを指定してください。 + +### メモリに余裕がある場合に精度や速度を上げる +まずgradient_checkpointingを外すと速度が上がります。ただし設定できるバッチサイズが減りますので、精度と速度のバランスを見ながら設定してください。 + +バッチサイズを増やすと速度、精度が上がります。メモリが足りる範囲で、1データ当たりの速度を確認しながら増やしてください(メモリがぎりぎりになるとかえって速度が落ちることがあります)。 + +### 使用するCLIP出力の変更 +clip_skipオプションに2を指定すると、後ろから二番目の層の出力を用います。1またはオプション省略時は最後の層を用います。 +学習したモデルはAutomatic1111氏のWeb UIで推論できるはずです。 + +※SD2.0はデフォルトで後ろから二番目の層を使うため、SD2.0の学習では指定しないでください。 + +学習対象のモデルがもともと二番目の層を使うように学習されている場合は、2を指定するとよいでしょう。 + +そうではなく最後の層を使用していた場合はモデル全体がそれを前提に学習されています。そのため改めて二番目の層を使用して学習すると、望ましい学習結果を得るにはある程度の枚数の教師データ、長めの学習が必要になるかもしれません。 + +### トークン長の拡張 +max_token_lengthに150または225を指定することでトークン長を拡張して学習できます。 +学習したモデルはAutomatic1111氏のWeb UIで推論できるはずです。 + +clip_skipと同様に、モデルの学習状態と異なる長さで学習するには、ある程度の教師データ枚数、長めの学習時間が必要になると思われます。 + +### 学習ログの保存 +logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 + +たとえば--logging_dir=logsと指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。 +また--log_prefixオプションを指定すると、日時の前に指定した文字列が追加されます。「--logging_dir=logs --log_prefix=fine_tune_style1」などとして識別用にお使いください。 + +TensorBoardでログを確認するには、別のコマンドプロンプトを開き、作業フォルダで以下のように入力します(tensorboardはDiffusersのインストール時にあわせてインストールされると思いますが、もし入っていないならpip install tensorboardで入れてください)。 +``` +tensorboard --logdir=logs +``` + +### Hypernetworkの学習 +別の記事で解説予定です。 + +### 勾配をfp16とした学習(実験的機能) +full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。これによりSD1.xの512*512サイズでは8GB未満、SD2.xの512*512サイズで12GB未満のVRAM使用量で学習できるようです。 + +あらかじめaccelerate configでfp16を指定し、オプションでmixed_precision="fp16"としてください(bf16では動作しません)。 + +メモリ使用量を最小化するためには、xformers、use_8bit_adam、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。 +(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。) + +PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。 + +### その他のオプション + +#### keep_tokens +数値を指定するとキャプションの先頭から、指定した数だけのトークン(カンマ区切りの文字列)をシャッフルせず固定します。 + +キャプションとタグが両方ある場合、学習時のプロンプトは「キャプション,タグ1,タグ2……」のように連結されますので、「--keep_tokens=1」とすれば、学習時にキャプションが必ず先頭に来るようになります。 + +#### dataset_repeats +データセットの枚数が極端に少ない場合、epochがすぐに終わってしまうため(epochの区切りで少し時間が掛かります)、数値を指定してデータを何倍かしてepochを長めにしてください。 + +#### train_text_encoder +Text Encoderも学習対象とします。メモリ使用量が若干増加します。 + +通常のfine tuningではText Encoderは学習対象としませんが(恐らくText Encoderの出力に従うようにU-Netを学習するため)、学習データ数が少ない場合には、DreamBoothのようにText Encoder側に学習させるのも有効的なようです。 + +#### save_precision +checkpoint保存時のデータ形式をfloat、fp16、bf16から指定できます(未指定時は学習中のデータ形式と同じ)。ディスク容量が節約できますがモデルによる生成結果は変わってきます。またfloatやfp16を指定すると、1111氏のWeb UIでも読めるようになるはずです。 + +※VAEについては元のcheckpointのデータ形式のままになりますので、fp16でもモデルサイズが2GB強まで小さくならない場合があります。 + +#### save_model_as +モデルの保存形式を指定します。ckpt、safetensors、diffusers、diffusers_safetensorsのいずれかを指定してください。 + +Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。 + +#### use_safetensors +このオプションを指定するとsafetensors形式でcheckpointを保存します。保存形式はデフォルト(読み込んだ形式と同じ)になります。 + +#### save_stateとresume +save_stateオプションで、途中保存時および最終保存時に、checkpointに加えてoptimizer等の学習状態をフォルダに保存します。これにより中断してから学習再開したときの精度低下が避けられます(optimizerは状態を持ちながら最適化をしていくため、その状態がリセットされると再び初期状態から最適化を行わなくてはなりません)。なお、Accelerateの仕様でステップ数は保存されません。 + +スクリプト起動時、resumeオプションで状態の保存されたフォルダを指定すると再開できます。 + +学習状態は一回の保存あたり5GB程度になりますのでディスク容量にご注意ください。 + +#### gradient_accumulation_steps +指定したステップ数だけまとめて勾配を更新します。バッチサイズを増やすのと同様の効果がありますが、メモリを若干消費します。 + +※Accelerateの仕様で学習モデルが複数の場合には対応していないとのことですので、Text Encoderを学習対象にして、このオプションに2以上の値を指定するとエラーになるかもしれません。 + +#### lr_scheduler / lr_warmup_steps +lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。 + +lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。詳細については各自お調べください。 + +#### diffusers_xformers +スクリプト独自のxformers置換機能ではなくDiffusersのxformers機能を利用します。Hypernetworkの学習はできなくなります。 diff --git a/gui.cmd b/gui.cmd deleted file mode 100644 index 379ff8d..0000000 --- a/gui.cmd +++ /dev/null @@ -1 +0,0 @@ -.\venv\Scripts\python.exe kohya_gui.py \ No newline at end of file diff --git a/gui.ps1 b/gui.ps1 new file mode 100644 index 0000000..4f799a1 --- /dev/null +++ b/gui.ps1 @@ -0,0 +1,2 @@ +.\venv\Scripts\activate +python.exe kohya_gui.py \ No newline at end of file diff --git a/library/model_util.py b/library/model_util.py index ad2b427..bc824a1 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -1133,14 +1133,6 @@ def load_vae(vae_id, dtype): return vae -def get_epoch_ckpt_name(use_safetensors, epoch): - return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") - - -def get_last_ckpt_name(use_safetensors): - return f"last" + (".safetensors" if use_safetensors else ".ckpt") - - # endregion @@ -1187,4 +1179,4 @@ if __name__ == '__main__': for ar in aspect_ratios: if ar in ars: print("error! duplicate ar:", ar) - ars.add(ar) \ No newline at end of file + ars.add(ar) diff --git a/library/train_util.py b/library/train_util.py new file mode 100644 index 0000000..bad954c --- /dev/null +++ b/library/train_util.py @@ -0,0 +1,1373 @@ +# common functions for training + +import argparse +import json +import shutil +import time +from typing import NamedTuple +from accelerate import Accelerator +from torch.autograd.function import Function +import glob +import math +import os +import random + +from tqdm import tqdm +import torch +from torchvision import transforms +from transformers import CLIPTokenizer +import diffusers +from diffusers import DDPMScheduler, StableDiffusionPipeline +import albumentations as albu +import numpy as np +from PIL import Image +import cv2 +from einops import rearrange +from torch import einsum + +import library.model_util as model_util + +# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ + +# checkpointファイル名 +EPOCH_STATE_NAME = "{}-{:06d}-state" +EPOCH_FILE_NAME = "{}-{:06d}" +EPOCH_DIFFUSERS_DIR_NAME = "{}-{:06d}" +LAST_STATE_NAME = "{}-state" +DEFAULT_EPOCH_NAME = "epoch" +DEFAULT_LAST_OUTPUT_NAME = "last" + +# region dataset + +IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp"] + + +class ImageInfo(): + def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: + self.image_key: str = image_key + self.num_repeats: int = num_repeats + self.caption: str = caption + self.is_reg: bool = is_reg + self.absolute_path: str = absolute_path + self.image_size: tuple[int, int] = None + self.bucket_reso: tuple[int, int] = None + self.latents: torch.Tensor = None + self.latents_flipped: torch.Tensor = None + self.latents_npz: str = None + self.latents_npz_flipped: str = None + + +class BucketBatchIndex(NamedTuple): + bucket_index: int + batch_index: int + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, random_crop, debug_dataset: bool) -> None: + super().__init__() + self.tokenizer: CLIPTokenizer = tokenizer + self.max_token_length = max_token_length + self.shuffle_caption = shuffle_caption + self.shuffle_keep_tokens = shuffle_keep_tokens + # width/height is used when enable_bucket==False + self.width, self.height = (None, None) if resolution is None else resolution + self.face_crop_aug_range = face_crop_aug_range + self.flip_aug = flip_aug + self.color_aug = color_aug + self.debug_dataset = debug_dataset + self.random_crop = random_crop + self.token_padding_disabled = False + + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 + + # augmentation + flip_p = 0.5 if flip_aug else 0.0 + if color_aug: + # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る + self.aug = albu.Compose([ + albu.OneOf([ + albu.HueSaturationValue(8, 0, 0, p=.5), + albu.RandomGamma((95, 105), p=.5), + ], p=.33), + albu.HorizontalFlip(p=flip_p) + ], p=1.) + elif flip_aug: + self.aug = albu.Compose([ + albu.HorizontalFlip(p=flip_p) + ], p=1.) + else: + self.aug = None + + self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) + + self.image_data: dict[str, ImageInfo] = {} + + def disable_token_padding(self): + self.token_padding_disabled = True + + def process_caption(self, caption): + if self.shuffle_caption: + tokens = caption.strip().split(",") + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + else: + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = keep_tokens + tokens + caption = ",".join(tokens).strip() + return caption + + def get_input_ids(self, caption): + input_ids = self.tokenizer(caption, padding="max_length", truncation=True, + max_length=self.tokenizer_max_length, return_tensors="pt").input_ids + + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) + + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id + + iids_list.append(ids_chunk) + + input_ids = torch.stack(iids_list) # 3,77 + return input_ids + + def register_image(self, info: ImageInfo): + self.image_data[info.image_key] = info + + def make_buckets(self): + ''' + bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) + min_size and max_size are ignored when enable_bucket is False + ''' + print("loading image sizes.") + for info in tqdm(self.image_data.values()): + if info.image_size is None: + info.image_size = self.get_image_size(info.absolute_path) + + if self.enable_bucket: + print("make buckets") + else: + print("prepare dataset") + + bucket_resos = self.bucket_resos + bucket_aspect_ratios = np.array(self.bucket_aspect_ratios) + + # bucketを作成する + if self.enable_bucket: + img_ar_errors = [] + for image_info in self.image_data.values(): + # bucketを決める + image_width, image_height = image_info.image_size + aspect_ratio = image_width / image_height + ar_errors = bucket_aspect_ratios - aspect_ratio + + bucket_id = np.abs(ar_errors).argmin() + image_info.bucket_reso = bucket_resos[bucket_id] + + ar_error = ar_errors[bucket_id] + img_ar_errors.append(ar_error) + else: + for image_info in self.image_data.values(): + image_info.bucket_reso = bucket_resos[0] # bucket_resos contains (width, height) only + + # 画像をbucketに分割する + self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] + reso_to_index = {} + for i, reso in enumerate(bucket_resos): + reso_to_index[reso] = i + + for image_info in self.image_data.values(): + bucket_index = reso_to_index[image_info.bucket_reso] + for _ in range(image_info.num_repeats): + self.buckets[bucket_index].append(image_info.image_key) + + if self.enable_bucket: + print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)") + for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): + print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") + img_ar_errors = np.array(img_ar_errors) + print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") + + # 参照用indexを作る + self.buckets_indices: list(BucketBatchIndex) = [] + for bucket_index, bucket in enumerate(self.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index)) + + self.shuffle_buckets() + self._length = len(self.buckets_indices) + + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + for bucket in self.buckets: + random.shuffle(bucket) + + def load_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + + def resize_and_trim(self, image, reso): + image_height, image_width = image.shape[0:2] + ar_img = image_width / image_height + ar_reso = reso[0] / reso[1] + if ar_img > ar_reso: # 横が長い→縦を合わせる + scale = reso[1] / image_height + else: + scale = reso[0] / image_width + resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) + + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if resized_size[0] > reso[0]: + trim_size = resized_size[0] - reso[0] + image = image[:, trim_size//2:trim_size//2 + reso[0]] + elif resized_size[1] > reso[1]: + trim_size = resized_size[1] - reso[1] + image = image[trim_size//2:trim_size//2 + reso[1]] + assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ + f"internal error, illegal trimmed size: {image.shape}, {reso}" + return image + + def cache_latents(self, vae): + print("caching latents.") + for info in tqdm(self.image_data.values()): + if info.latents_npz is not None: + info.latents = self.load_latents_from_npz(info, False) + info.latents = torch.FloatTensor(info.latents) + info.latents_flipped = self.load_latents_from_npz(info, True) # might be None + if info.latents_flipped is not None: + info.latents_flipped = torch.FloatTensor(info.latents_flipped) + continue + + image = self.load_image(info.absolute_path) + image = self.resize_and_trim(image, info.bucket_reso) + + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + if self.flip_aug: + image = image[:, ::-1].copy() # cannot convert to Tensor without copy + img_tensor = self.image_transforms(image) + img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) + info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + + def get_image_size(self, image_path): + image = Image.open(image_path) + return image.size + + def load_image_with_face_info(self, image_path: str): + img = self.load_image(image_path) + + face_cx = face_cy = face_w = face_h = 0 + if self.face_crop_aug_range is not None: + tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') + if len(tokens) >= 5: + face_cx = int(tokens[-4]) + face_cy = int(tokens[-3]) + face_w = int(tokens[-2]) + face_h = int(tokens[-1]) + + return img, face_cx, face_cy, face_w, face_h + + # いい感じに切り出す + def crop_target(self, image, face_cx, face_cy, face_w, face_h): + height, width = image.shape[0:2] + if height == self.height and width == self.width: + return image + + # 画像サイズはsizeより大きいのでリサイズする + face_size = max(face_w, face_h) + min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) + min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ + max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ + if min_scale >= max_scale: # range指定がmin==max + scale = min_scale + else: + scale = random.uniform(min_scale, max_scale) + + nh = int(height * scale + .5) + nw = int(width * scale + .5) + assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" + image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) + face_cx = int(face_cx * scale + .5) + face_cy = int(face_cy * scale + .5) + height, width = nh, nw + + # 顔を中心として448*640とかへ切り出す + for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): + p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 + + if self.random_crop: + # 背景も含めるために顔を中心に置く確率を高めつつずらす + range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう + p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 + else: + # range指定があるときのみ、すこしだけランダムに(わりと適当) + if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: + if face_size > self.size // 10 and face_size >= 40: + p1 = p1 + random.randint(-face_size // 20, +face_size // 20) + + p1 = max(0, min(p1, length - target_size)) + + if axis == 0: + image = image[p1:p1 + target_size, :] + else: + image = image[:, p1:p1 + target_size] + + return image + + def load_latents_from_npz(self, image_info: ImageInfo, flipped): + npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz + if npz_file is None: + return None + return np.load(npz_file)['arr_0'] + + def __len__(self): + return self._length + + def __getitem__(self, index): + if index == 0: + self.shuffle_buckets() + + bucket = self.buckets[self.buckets_indices[index].bucket_index] + image_index = self.buckets_indices[index].batch_index * self.batch_size + + loss_weights = [] + captions = [] + input_ids_list = [] + latents_list = [] + images = [] + + for image_key in bucket[image_index:image_index + self.batch_size]: + image_info = self.image_data[image_key] + loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) + + # image/latentsを処理する + if image_info.latents is not None: + latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped + image = None + elif image_info.latents_npz is not None: + latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) + latents = torch.FloatTensor(latents) + image = None + else: + # 画像を読み込み、必要ならcropする + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) + im_h, im_w = img.shape[0:2] + + if self.enable_bucket: + img = self.resize_and_trim(img, image_info.bucket_reso) + else: + if face_cx > 0: # 顔位置情報あり + img = self.crop_target(img, face_cx, face_cy, face_w, face_h) + elif im_h > self.height or im_w > self.width: + assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" + if im_h > self.height: + p = random.randint(0, im_h - self.height) + img = img[p:p + self.height] + if im_w > self.width: + p = random.randint(0, im_w - self.width) + img = img[:, p:p + self.width] + + im_h, im_w = img.shape[0:2] + assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" + + # augmentation + if self.aug is not None: + img = self.aug(image=img)['image'] + + latents = None + image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + + images.append(image) + latents_list.append(latents) + + caption = self.process_caption(image_info.caption) + captions.append(caption) + if not self.token_padding_disabled: # this option might be omitted in future + input_ids_list.append(self.get_input_ids(caption)) + + example = {} + example['loss_weights'] = torch.FloatTensor(loss_weights) + + if self.token_padding_disabled: + # padding=True means pad in the batch + example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids + else: + # batch processing seems to be good + example['input_ids'] = torch.stack(input_ids_list) + + if images[0] is not None: + images = torch.stack(images) + images = images.to(memory_format=torch.contiguous_format).float() + else: + images = None + example['images'] = images + + example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None + + if self.debug_dataset: + example['image_keys'] = bucket[image_index:image_index + self.batch_size] + example['captions'] = captions + return example + + +class DreamBoothDataset(BaseDataset): + def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) + + assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です" + + self.batch_size = batch_size + self.size = min(self.width, self.height) # 短いほう + self.prior_loss_weight = prior_loss_weight + self.latents_cache = None + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( + (self.width, self.height), min_bucket_reso, max_bucket_reso) + else: + self.bucket_resos = [(self.width, self.height)] + self.bucket_aspect_ratios = [self.width / self.height] + + def read_caption(img_path): + # captionの候補ファイル名を作る + base_name = os.path.splitext(img_path)[0] + base_name_face_det = base_name + tokens = base_name.split("_") + if len(tokens) >= 5: + base_name_face_det = "_".join(tokens[:-4]) + cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] + + caption = None + for cap_path in cap_paths: + if os.path.isfile(cap_path): + with open(cap_path, "rt", encoding='utf-8') as f: + try: + lines = f.readlines() + except UnicodeDecodeError as e: + print(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}") + raise e + assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" + caption = lines[0].strip() + break + return caption + + def load_dreambooth_dir(dir): + if not os.path.isdir(dir): + # print(f"ignore file: {dir}") + return 0, [], [] + + tokens = os.path.basename(dir).split('_') + try: + n_repeats = int(tokens[0]) + except ValueError as e: + print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") + return 0, [], [] + + caption_by_folder = '_'.join(tokens[1:]) + img_paths = glob_images(dir, "*") + print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") + + # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う + captions = [] + for img_path in img_paths: + cap_for_img = read_caption(img_path) + captions.append(caption_by_folder if cap_for_img is None else cap_for_img) + + return n_repeats, img_paths, captions + + print("prepare train images.") + train_dirs = os.listdir(train_data_dir) + num_train_images = 0 + for dir in train_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) + num_train_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, False, img_path) + self.register_image(info) + print(f"{num_train_images} train images with repeating.") + self.num_train_images = num_train_images + + # reg imageは数を数えて学習画像と同じ枚数にする + num_reg_images = 0 + if reg_data_dir: + print("prepare reg images.") + reg_infos: list[ImageInfo] = [] + + reg_dirs = os.listdir(reg_data_dir) + for dir in reg_dirs: + n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) + num_reg_images += n_repeats * len(img_paths) + for img_path, caption in zip(img_paths, captions): + info = ImageInfo(img_path, n_repeats, caption, True, img_path) + reg_infos.append(info) + + print(f"{num_reg_images} reg images.") + if num_train_images < num_reg_images: + print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") + + if num_reg_images == 0: + print("no regularization images / 正則化画像が見つかりませんでした") + else: + # num_repeatsを計算する:どうせ大した数ではないのでループで処理する + n = 0 + first_loop = True + while n < num_train_images: + for info in reg_infos: + if first_loop: + self.register_image(info) + n += info.num_repeats + else: + info.num_repeats += 1 + n += 1 + if n >= num_train_images: + break + first_loop = False + + self.num_reg_images = num_reg_images + + +class FineTuningDataset(BaseDataset): + def __init__(self, json_file_name, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, flip_aug, color_aug, face_crop_aug_range, random_crop, dataset_repeats, debug_dataset) -> None: + super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, + resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) + + # メタデータを読み込む + if os.path.exists(json_file_name): + print(f"loading existing metadata: {json_file_name}") + with open(json_file_name, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + raise ValueError(f"no metadata / メタデータファイルがありません: {json_file_name}") + + self.metadata = metadata + self.train_data_dir = train_data_dir + self.batch_size = batch_size + + for image_key, img_md in metadata.items(): + # path情報を作る + if os.path.exists(image_key): + abs_path = image_key + else: + # わりといい加減だがいい方法が思いつかん + abs_path = glob_images(train_data_dir, image_key) + assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" + abs_path = abs_path[0] + + caption = img_md.get('caption') + tags = img_md.get('tags') + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" + + image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) + image_info.image_size = img_md.get('train_resolution') + + if not self.color_aug: + # if npz exists, use them + image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) + + self.register_image(image_info) + self.num_train_images = len(metadata) * dataset_repeats + self.num_reg_images = 0 + + # check existence of all npz files + if not self.color_aug: + npz_any = False + npz_all = True + for image_info in self.image_data.values(): + has_npz = image_info.latents_npz is not None + npz_any = npz_any or has_npz + + if self.flip_aug: + has_npz = has_npz and image_info.latents_npz_flipped is not None + npz_all = npz_all and has_npz + + if npz_any and not npz_all: + break + + if not npz_any: + print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") + elif not npz_all: + print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") + for image_info in self.image_data.values(): + image_info.latents_npz = image_info.latents_npz_flipped = None + + # check min/max bucket size + sizes = set() + resos = set() + for image_info in self.image_data.values(): + if image_info.image_size is None: + sizes = None # not calculated + break + sizes.add(image_info.image_size[0]) + sizes.add(image_info.image_size[1]) + resos.add(tuple(image_info.image_size)) + + if sizes is None: + assert resolution is not None, "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください" + + self.enable_bucket = enable_bucket + if self.enable_bucket: + assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( + (self.width, self.height), min_bucket_reso, max_bucket_reso) + else: + self.bucket_resos = [(self.width, self.height)] + self.bucket_aspect_ratios = [self.width / self.height] + else: + if not enable_bucket: + print("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします") + print("using bucket info in metadata / メタデータ内のbucket情報を使います") + self.enable_bucket = True + self.bucket_resos = list(resos) + self.bucket_resos.sort() + self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] + + def image_key_to_npz_file(self, image_key): + base_name = os.path.splitext(image_key)[0] + npz_file_norm = base_name + '.npz' + + if os.path.exists(npz_file_norm): + # image_key is full path + npz_file_flip = base_name + '_flip.npz' + if not os.path.exists(npz_file_flip): + npz_file_flip = None + return npz_file_norm, npz_file_flip + + # image_key is relative path + npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') + npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') + + if not os.path.exists(npz_file_norm): + npz_file_norm = None + npz_file_flip = None + elif not os.path.exists(npz_file_flip): + npz_file_flip = None + + return npz_file_norm, npz_file_flip + + +def debug_dataset(train_dataset): + print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") + print("Escape for exit. / Escキーで中断、終了します") + k = 0 + for example in train_dataset: + if example['latents'] is not None: + print("sample has latents from npz file") + for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])): + print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') + if example['images'] is not None: + im = example['images'][j] + im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) + im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c + im = im[:, :, ::-1] # RGB -> BGR (OpenCV) + cv2.imshow("img", im) + k = cv2.waitKey() + cv2.destroyAllWindows() + if k == 27: + break + if k == 27 or example['images'] is None: + break + + +def glob_images(dir, base): + img_paths = [] + for ext in IMAGE_EXTENSIONS: + img_paths.extend(glob.glob(os.path.join(dir, base + ext))) + return img_paths + +# endregion + + +# region モジュール入れ替え部 +""" +高速化のためのモジュール入れ替え +""" + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.function.Function): + @ staticmethod + @ torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """ Algorithm 2 in the paper """ + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = (q.shape[-1] ** -0.5) + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, 'b n -> b 1 1 n') + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @ staticmethod + @ torch.no_grad() + def backward(ctx, do): + """ Algorithm 4 in the paper """ + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2) + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.) + + p = exp_attn_weights / lc + + dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) + dp = einsum('... i d, ... j d -> ... i j', doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) + dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() + + +def replace_unet_cross_attn_to_memory_efficient(): + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction + + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + + context = context if context is not None else x + context = context.to(x.dtype) + + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, 'b h n d -> b n (h d)') + + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_flash_attn + + +def replace_unet_cross_attn_to_xformers(): + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") + + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + + context = default(context, x) + context = context.to(x.dtype) + + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context + + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + + diffusers.models.attention.CrossAttention.forward = forward_xformers +# endregion + + +# region arguments + +def add_sd_models_arguments(parser: argparse.ArgumentParser): + # for pretrained models + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization training / v-parameterization学習を有効にする') + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") + + +def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): + parser.add_argument("--output_dir", type=str, default=None, + help="directory to output trained model / 学習後のモデル出力先ディレクトリ") + parser.add_argument("--output_name", type=str, default=None, + help="base name of trained model file / 学習後のモデルの拡張子を除くファイル名") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + parser.add_argument("--save_every_n_epochs", type=int, default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") + + parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") + parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + parser.add_argument("--use_8bit_adam", action="store_true", + help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--mem_eff_attn", action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") + parser.add_argument("--xformers", action="store_true", + help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--vae", type=str, default=None, + help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") + + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument("--gradient_checkpointing", action="store_true", + help="enable gradient checkpointing / grandient checkpointingを有効にする") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + + if support_dreambooth: + # DreamBooth training + parser.add_argument("--prior_loss_weight", type=float, default=1.0, + help="loss weight for regularization images / 正則化画像のlossの重み") + + +def verify_training_args(args: argparse.Namespace): + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + + +def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): + # dataset common + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--shuffle_caption", action="store_true", + help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") + parser.add_argument("--caption_extention", type=str, default=None, + help="extension of caption files (backward compatibility) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") + parser.add_argument("--keep_tokens", type=int, default=None, + help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") + parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") + parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") + parser.add_argument("--face_crop_aug_range", type=str, default=None, + help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") + parser.add_argument("--random_crop", action="store_true", + help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") + parser.add_argument("--debug_dataset", action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") + parser.add_argument("--resolution", type=str, default=None, + help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") + parser.add_argument("--cache_latents", action="store_true", + help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") + parser.add_argument("--enable_bucket", action="store_true", + help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") + parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") + parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + + if support_dreambooth: + # DreamBooth dataset + parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") + + if support_caption: + # caption dataset + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument("--dataset_repeats", type=int, default=1, + help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") + + +def add_sd_saving_arguments(parser: argparse.ArgumentParser): + parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") + parser.add_argument("--use_safetensors", action='store_true', + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") + +# endregion + +# region utils + + +def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): + # backward compatibility + if args.caption_extention is not None: + args.caption_extension = args.caption_extention + args.caption_extention = None + + if args.cache_latents: + assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" + + # assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" + if args.resolution is not None: + args.resolution = tuple([int(r) for r in args.resolution.split(',')]) + if len(args.resolution) == 1: + args.resolution = (args.resolution[0], args.resolution[0]) + assert len(args.resolution) == 2, \ + f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + + if args.face_crop_aug_range is not None: + args.face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) + assert len(args.face_crop_aug_range) == 2, \ + f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" + else: + args.face_crop_aug_range = None + + if support_metadata: + if args.in_json is not None and args.color_aug: + print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") + + +def load_tokenizer(args: argparse.Namespace): + print("prepare tokenizer") + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + if args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + return tokenizer + + +def prepare_accelerator(args: argparse.Namespace): + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, + log_with=log_with, logging_dir=logging_dir) + + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + return accelerator, unwrap_model + + +def prepare_dtype(args: argparse.Namespace): + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + return weight_dtype, save_dtype + + +def load_target_model(args: argparse.Namespace, weight_dtype): + load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) # determine SD or Diffusers + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) + else: + print("load Diffusers pretrained models") + pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) + text_encoder = pipe.text_encoder + vae = pipe.vae + unet = pipe.unet + del pipe + + # VAEを読み込む + if args.vae is not None: + vae = model_util.load_vae(args.vae, weight_dtype) + print("additional VAE loaded") + + return text_encoder, vae, unet, load_stable_diffusion_format + + +def patch_accelerator_for_fp16_training(accelerator): + org_unscale_grads = accelerator.scaler._unscale_grads_ + + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) + + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + + +def get_hidden_states(args: argparse.Namespace, input_ids, tokenizer, text_encoder, weight_dtype=None): + # with no_token_padding, the length is not max length, return result immediately + if input_ids.size()[-1] != tokenizer.model_max_length: + return text_encoder(input_ids)[0] + + b_size = input_ids.size()[0] + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 + + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + if weight_dtype is not None: + # this is required for additional network training + encoder_hidden_states = encoder_hidden_states.to(weight_dtype) + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) + + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) + + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) + else: + # v1: ... の三連を ... へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) + + return encoder_hidden_states + + +def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch): + model_name = DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + ckpt_name = EPOCH_FILE_NAME.format(model_name, epoch) + (".safetensors" if use_safetensors else ".ckpt") + return model_name, ckpt_name + + +def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): + saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs + remove_epoch_no = None + if saving: + os.makedirs(args.output_dir, exist_ok=True) + save_func() + + if args.save_last_n_epochs is not None: + remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs + remove_old_func(remove_epoch_no) + return saving, remove_epoch_no + + +def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): + epoch_no = epoch + 1 + model_name, ckpt_name = get_epoch_ckpt_name(args, use_safetensors, epoch_no) + + if save_stable_diffusion_format: + def save_sd(): + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch_no, global_step, save_dtype, vae) + + def remove_sd(old_epoch_no): + _, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no) + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + save_func = save_sd + remove_old_func = remove_sd + else: + def save_du(): + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no)) + print(f"saving model: {out_dir}") + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) + + def remove_du(old_epoch_no): + out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no)) + if os.path.exists(out_dir_old): + print(f"removing old model: {out_dir_old}") + shutil.rmtree(out_dir_old) + + save_func = save_du + remove_old_func = remove_du + + saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) + if saving and args.save_state: + save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no) + + +def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no): + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) + if remove_epoch_no is not None: + state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) + if os.path.exists(state_dir_old): + print(f"removing old state: {state_dir_old}") + shutil.rmtree(state_dir_old) + + +def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae): + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + + if save_stable_diffusion_format: + os.makedirs(args.output_dir, exist_ok=True) + + ckpt_name = model_name + (".safetensors" if use_safetensors else ".ckpt") + ckpt_file = os.path.join(args.output_dir, ckpt_name) + + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_path, epoch, global_step, save_dtype, vae) + else: + out_dir = os.path.join(args.output_dir, model_name) + os.makedirs(out_dir, exist_ok=True) + + print(f"save trained model as Diffusers to {out_dir}") + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_path, vae=vae, use_safetensors=use_safetensors) + + +def save_state_on_train_end(args: argparse.Namespace, accelerator): + print("saving last state.") + os.makedirs(args.output_dir, exist_ok=True) + model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name))) + + +# endregion diff --git a/networks/extract_lora_from_models.py b/networks/extract_lora_from_models.py index ae586f1..c882e88 100644 --- a/networks/extract_lora_from_models.py +++ b/networks/extract_lora_from_models.py @@ -155,4 +155,4 @@ if __name__ == '__main__': parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") args = parser.parse_args() - svd(args) \ No newline at end of file + svd(args) diff --git a/train_db.py b/train_db.py index 75186f5..8c9cdb9 100644 --- a/train_db.py +++ b/train_db.py @@ -1,685 +1,74 @@ -# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします -# (c) 2022 Kohya S. @kohya_ss - -# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images, -# enable reg images in fine-tuning, add dataset_repeats option -# v8: supports Diffusers 0.7.2 -# v9: add bucketing option -# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth -# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization -# add lr scheduler options, change handling folder/file caption, support loading DiffUser model from Huggingface -# support save_ever_n_epochs/save_state in DiffUsers model -# fix the issue that prior_loss_weight is applyed to train images -# v12: stop train text encode, tqdm smoothing -# v13: bug fix -# v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep vae in CPU to save the loaded vae -# v15: model_util update -# v16: support Diffusers 0.10.0 (v-parameterization training, safetensors in Diffusers) and accelerate 0.15.0 -# v17: add fp16 gradient training (experimental) -# v18: add save_model_as option +# DreamBooth training +# XXX dropped option: fine_tune import gc import time -from torch.autograd.function import Function import argparse -import glob import itertools import math import os -import random from tqdm import tqdm import torch -from torchvision import transforms -from accelerate import Accelerator from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import albumentations as albu -import numpy as np -from PIL import Image -import cv2 -from einops import rearrange -from torch import einsum +from diffusers import DDPMScheduler -import library.model_util as model_util - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336" - -# checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" -LAST_DIFFUSERS_DIR_NAME = "last" - - -# region dataset - -class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset): - def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None: - super().__init__() - - self.batch_size = batch_size - self.fine_tuning = fine_tuning - self.train_img_path_captions = train_img_path_captions - self.reg_img_path_captions = reg_img_path_captions - self.tokenizer = tokenizer - self.width, self.height = resolution - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.face_crop_aug_range = face_crop_aug_range - self.random_crop = random_crop - self.debug_dataset = debug_dataset - self.shuffle_caption = shuffle_caption - self.disable_padding = disable_padding - self.latents_cache = None - self.enable_bucket = False - - # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hue/saturationあたりを触る - self.aug = albu.Compose([ - albu.OneOf([ - # albu.RandomBrightnessContrast(0.05, 0.05, p=.2), - albu.HueSaturationValue(5, 8, 0, p=.2), - # albu.RGBShift(5, 5, 5, p=.1), - albu.RandomGamma((95, 105), p=.5), - ], p=.33), - albu.HorizontalFlip(p=flip_p) - ], p=1.) - elif flip_aug: - self.aug = albu.Compose([ - albu.HorizontalFlip(p=flip_p) - ], p=1.) - else: - self.aug = None - - self.num_train_images = len(self.train_img_path_captions) - self.num_reg_images = len(self.reg_img_path_captions) - - self.enable_reg_images = self.num_reg_images > 0 - - if self.enable_reg_images and self.num_train_images < self.num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - self.image_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] - ) - - # bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size): - self.enable_bucket = enable_bucket - - cache_latents = vae is not None - if cache_latents: - if enable_bucket: - print("cache latents with bucketing") - else: - print("cache latents") - else: - if enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketingを用意する - if enable_bucket: - bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) - else: - # bucketはひとつだけ、すべての画像は同じ解像度 - bucket_resos = [(self.width, self.height)] - bucket_aspect_ratios = [self.width / self.height] - bucket_aspect_ratios = np.array(bucket_aspect_ratios) - - # 画像の解像度、latentをあらかじめ取得する - img_ar_errors = [] - self.size_lat_cache = {} - for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions): - if image_path in self.size_lat_cache: - continue - - image = self.load_image(image_path)[0] - image_height, image_width = image.shape[0:2] - - if not enable_bucket: - # assert image_width == self.width and image_height == self.height, \ - # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}" - reso = (self.width, self.height) - else: - # bucketを決める - aspect_ratio = image_width / image_height - ar_errors = bucket_aspect_ratios - aspect_ratio - bucket_id = np.abs(ar_errors).argmin() - reso = bucket_resos[bucket_id] - ar_error = ar_errors[bucket_id] - img_ar_errors.append(ar_error) - - if cache_latents: - image = self.resize_and_trim(image, reso) - - # latentを取得する - if cache_latents: - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - else: - latents = None - - self.size_lat_cache[image_path] = (reso, latents) - - # 画像をbucketに分割する - self.buckets = [[] for _ in range(len(bucket_resos))] - reso_to_index = {} - for i, reso in enumerate(bucket_resos): - reso_to_index[reso] = i - - def split_to_buckets(is_reg, img_path_captions): - for image_path, caption in img_path_captions: - reso, _ = self.size_lat_cache[image_path] - bucket_index = reso_to_index[reso] - self.buckets[bucket_index].append((is_reg, image_path, caption)) - - split_to_buckets(False, self.train_img_path_captions) - - if self.enable_reg_images: - l = [] - while len(l) < len(self.train_img_path_captions): - l += self.reg_img_path_captions - l = l[:len(self.train_img_path_captions)] - split_to_buckets(True, l) - - if enable_bucket: - print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数") - for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(imgs)}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}") - - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - # どのサイズにリサイズするか→トリミングする方向で - def resize_and_trim(self, image, reso): - image_height, image_width = image.shape[0:2] - ar_img = image_width / image_height - ar_reso = reso[0] / reso[1] - if ar_img > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] - elif resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ - f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - - face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへを切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if self.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size - - latents_list = [] - images = [] - captions = [] - loss_weights = [] - - for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]: - loss_weights.append(self.prior_loss_weight if is_reg else 1.0) - - # image/latentsを処理する - reso, latents = self.size_lat_cache[image_path] - - if latents is None: - # 画像を読み込み必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image(image_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.resize_and_trim(img, reso) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}" - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - else: - image = None - - images.append(image) - latents_list.append(latents) - - # captionを処理する - if self.shuffle_caption: # captionのshuffleをする - tokens = caption.strip().split(",") - random.shuffle(tokens) - caption = ",".join(tokens).strip() - captions.append(caption) - - # input_idsをpadしてTensor変換 - if self.disable_padding: - # paddingしない:padding==Trueはバッチの中の最大長に合わせるだけ(やはりバグでは……?) - input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids - else: - # paddingする - input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids - - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = input_ids - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None - if self.debug_dataset: - example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]] - example['captions'] = captions - return example -# endregion - - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = (q.shape[-1] ** -0.5) - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) - - p = exp_attn_weights / lc - - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, 'b h n d -> b n (h d)') - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - - context = default(context, x) - context = context.to(x.dtype) - - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format - # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format - del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers -# endregion +import library.train_util as train_util +from library.train_util import DreamBoothDataset def collate_fn(examples): return examples[0] -# def load_clip_l14_336(dtype): -# print(f"loading CLIP: {CLIP_ID_L14_336}") -# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype) -# return text_encoder - - def train(args): - if args.caption_extention is not None: - args.caption_extension = args.caption_extention - args.caption_extention = None + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, False) - fine_tuning = args.fine_tuning cache_latents = args.cache_latents - # latentsをキャッシュする場合のオプション設定を確認する - if cache_latents: - assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません" + if args.seed is not None: + set_seed(args.seed) # 乱数系列を初期化する - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") + tokenizer = train_util.load_tokenizer(args) - # モデル形式のオプション設定を確認する: - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) + train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, + tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) + if args.no_token_padding: + train_dataset.disable_token_padding() + train_dataset.make_buckets() + if args.debug_dataset: + train_util.debug_dataset(train_dataset) + return + + # acceleratorを準備する + print("prepare accelerator") + + if args.gradient_accumulation_steps > 1: + print(f"gradient_accumulation_steps is {args.gradient_accumulation_steps}. accelerate does not support gradient_accumulation_steps when training multiple models (U-Net and Text Encoder), so something might be wrong") + print( + f"gradient_accumulation_stepsが{args.gradient_accumulation_steps}に設定されています。accelerateは複数モデル(U-NetおよびText Encoder)の学習時にgradient_accumulation_stepsをサポートしていないため結果は未知数です") + + accelerator, unwrap_model = train_util.prepare_accelerator(args) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype, save_dtype = train_util.prepare_dtype(args) + + # モデルを読み込む + text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype) + + # verify load/save model formats if load_stable_diffusion_format: src_stable_diffusion_ckpt = args.pretrained_model_name_or_path src_diffusers_model_path = None else: src_stable_diffusion_ckpt = None src_diffusers_model_path = args.pretrained_model_name_or_path - + if args.save_model_as is None: save_stable_diffusion_format = load_stable_diffusion_format use_safetensors = args.use_safetensors @@ -687,204 +76,8 @@ def train(args): save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) - # 乱数系列を初期化する - if args.seed is not None: - set_seed(args.seed) - - # 学習データを用意する - def read_caption(img_path): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + args.caption_extension, base_name_face_det + args.caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - lines = f.readlines() - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(dir): - tokens = os.path.basename(dir).split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - return 0, [] - - caption_by_folder = '_'.join(tokens[1:]) - - print(f"found directory {n_repeats}_{caption_by_folder}") - - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ - glob.glob(os.path.join(dir, "*.webp")) - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う(v11から仕様変更した) - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - - return n_repeats, list(zip(img_paths, captions)) - - print("prepare train images.") - train_img_path_captions = [] - - if fine_tuning: - img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \ - glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp")) - for img_path in tqdm(img_paths): - caption = read_caption(img_path) - assert caption is not None and len( - caption) > 0, f"no caption for image. check caption_extension option / キャプションファイルが見つからないかcaptionが空です。caption_extensionオプションを確認してください: {img_path}" - - train_img_path_captions.append((img_path, caption)) - - if args.dataset_repeats is not None: - l = [] - for _ in range(args.dataset_repeats): - l.extend(train_img_path_captions) - train_img_path_captions = l - else: - train_dirs = os.listdir(args.train_data_dir) - for dir in train_dirs: - n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir)) - for _ in range(n_repeats): - train_img_path_captions.extend(img_caps) - print(f"{len(train_img_path_captions)} train images with repeating.") - - reg_img_path_captions = [] - if args.reg_data_dir: - print("prepare reg images.") - reg_dirs = os.listdir(args.reg_data_dir) - for dir in reg_dirs: - n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir)) - for _ in range(n_repeats): - reg_img_path_captions.extend(img_caps) - print(f"{len(reg_img_path_captions)} reg images.") - - # データセットを準備する - resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(resolution) == 1: - resolution = (resolution[0], resolution[0]) - assert len(resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - - if args.enable_bucket: - assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or greater than resolution / min_bucket_resoは解像度の数値以上で指定してください" - assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or less than resolution / max_bucket_resoは解像度の数値以下で指定してください" - - if args.face_crop_aug_range is not None: - face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len( - face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - face_crop_aug_range = None - - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) - - print("prepare dataset") - train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, - args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, - args.shuffle_caption, args.no_token_padding, args.debug_dataset) - - if args.debug_dataset: - train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, - args.max_bucket_reso) # デバッグ用にcacheなしで作る - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") - for example in train_dataset: - for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']): - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}') - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27: - break - return - - # acceleratorを準備する - # gradient accumulationは複数モデルを学習する場合には対応していないとのことなので、1固定にする - print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) - - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 - - # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe - - # # 置換するCLIPを読み込む - # if args.replace_clip_l14_336: - # text_encoder = load_clip_l14_336(weight_dtype) - # print(f"large clip {CLIP_ID_L14_336} is loaded") - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") - # モデルに xformers とか memory efficient attention を組み込む - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) # 学習を準備する if cache_latents: @@ -892,23 +85,31 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso) + train_dataset.cache_latents(vae) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() - else: - train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso) - vae.requires_grad_(False) - vae.eval() + # 学習を準備する:モデルを適切な状態にする + if args.stop_text_encoder_training is None: + args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end + + train_text_encoder = args.stop_text_encoder_training >= 0 unet.requires_grad_(True) # 念のため追加 - text_encoder.requires_grad_(True) + text_encoder.requires_grad_(train_text_encoder) + if not train_text_encoder: + print("Text Encoder is not trained.") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() text_encoder.gradient_checkpointing_enable() + if not cache_latents: + vae.requires_grad_(False) + vae.eval() + vae.to(accelerator.device, dtype=weight_dtype) + # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") @@ -923,7 +124,10 @@ def train(args): else: optimizer_class = torch.optim.AdamW - trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) + if train_text_encoder: + trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + trainable_params = unet.parameters() # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 optimizer = optimizer_class(trainable_params, lr=args.learning_rate) @@ -946,20 +150,18 @@ def train(args): text_encoder.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + if train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) - if not cache_latents: - vae.to(accelerator.device, dtype=weight_dtype) + if not train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -967,50 +169,46 @@ def train(args): accelerator.load_state(args.resume) # epoch数を計算する - num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader)) + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # 学習する - total_batch_size = args.train_batch_size # * accelerator.num_processes + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps print("running training / 学習開始") print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}") print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}") - print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}") print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") print(f" num epochs / epoch数: {num_train_epochs}") print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") global_step = 0 - # v12で更新:clip_sample=Falseに - # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschdulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀')  noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False) if accelerator.is_main_process: accelerator.init_trackers("dreambooth") - # 以下 train_dreambooth.py からほぼコピペ for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") # 指定したステップ数までText Encoderを学習する:epoch最初の状態 - train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training unet.train() - if train_text_encoder: + # train==True is required to enable gradient_checkpointing + if args.gradient_checkpointing or global_step < args.stop_text_encoder_training: text_encoder.train() loss_total = 0 for step, batch in enumerate(train_dataloader): # 指定したステップ数でText Encoderの学習を止める - stop_text_encoder_training = args.stop_text_encoder_training is not None and global_step == args.stop_text_encoder_training - if stop_text_encoder_training: + if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") - text_encoder.train(False) + if not args.gradient_checkpointing: + text_encoder.train(False) text_encoder.requires_grad_(False) with accelerator.accumulate(unet): @@ -1026,6 +224,12 @@ def train(args): noise = torch.randn_like(latents, device=latents.device) b_size = latents.shape[0] + # Get the text embedding for conditioning + with torch.set_grad_enabled(global_step < args.stop_text_encoder_training): + input_ids = batch["input_ids"].to(accelerator.device) + encoder_hidden_states = train_util.get_hidden_states( + args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype) + # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) timesteps = timesteps.long() @@ -1034,20 +238,11 @@ def train(args): # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Get the text embedding for conditioning - if args.clip_skip is None: - encoder_hidden_states = text_encoder(batch["input_ids"])[0] - else: - enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -1062,7 +257,10 @@ def train(args): accelerator.backward(loss) if accelerator.sync_gradients: - params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) + if train_text_encoder: + params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters())) + else: + params_to_clip = unet.parameters() accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) optimizer.step() @@ -1094,23 +292,9 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - if save_stable_diffusion_format: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), - src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) - else: - out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), - unwrap_model(unet), src_diffusers_model_path, - use_safetensors=use_safetensors) - - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_epoch_end(args, accelerator, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, num_train_epochs, global_step, unwrap_model(text_encoder), unwrap_model(unet), vae) is_main_process = accelerator.is_main_process if is_main_process: @@ -1120,110 +304,29 @@ def train(args): accelerator.end_training() if args.save_state: - print("saving last state.") - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - if save_stable_diffusion_format: - ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) - print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") - model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, - src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) - else: - print(f"save trained model as Diffusers to {args.output_dir}") - out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, src_diffusers_model_path, - use_safetensors=use_safetensors) + src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path + train_util.save_sd_model_on_train_end(args, src_path, save_stable_diffusion_format, use_safetensors, + save_dtype, epoch, global_step, text_encoder, unet, vae) print("model saved.") if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - # parser.add_argument("--replace_clip_l14_336", action='store_true', - # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える") - parser.add_argument("--fine_tuning", action="store_true", - help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--caption_extention", type=str, default=None, - help="extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - parser.add_argument("--dataset_repeats", type=int, default=None, - help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") - parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], - help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") - parser.add_argument("--use_safetensors", action='store_true', - help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, False) + train_util.add_training_arguments(parser, True) + train_util.add_sd_saving_arguments(parser) + parser.add_argument("--no_token_padding", action="store_true", help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)") parser.add_argument("--stop_text_encoder_training", type=int, default=None, - help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--train_batch_size", type=int, default=1, - help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度") - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない") args = parser.parse_args() - train(args) \ No newline at end of file + train(args) diff --git a/train_db_README-ja.md b/train_db_README-ja.md new file mode 100644 index 0000000..53ee715 --- /dev/null +++ b/train_db_README-ja.md @@ -0,0 +1,296 @@ +DreamBoothのガイドです。LoRA等の追加ネットワークの学習にも同じ手順を使います。 + +# 概要 + +スクリプトの主な機能は以下の通りです。 + +- 8bit Adam optimizerおよびlatentのキャッシュによる省メモリ化(ShivamShrirao氏版と同様)。 +- xformersによる省メモリ化。 +- 512x512だけではなく任意サイズでの学習。 +- augmentationによる品質の向上。 +- DreamBoothだけではなくText Encoder+U-Netのfine tuningに対応。 +- StableDiffusion形式でのモデルの読み書き。 +- Aspect Ratio Bucketing。 +- Stable Diffusion v2.0対応。 + +# 学習の手順 + +## step 1. 環境整備 + +このリポジトリのREADMEを参照してください。 + + +## step 2. identifierとclassを決める + +学ばせたい対象を結びつける単語identifierと、対象の属するclassを決めます。 + +(instanceなどいろいろな呼び方がありますが、とりあえず元の論文に合わせます。) + +以下ごく簡単に説明します(詳しくは調べてください)。 + +classは学習対象の一般的な種別です。たとえば特定の犬種を学ばせる場合には、classはdogになります。アニメキャラならモデルによりboyやgirl、1boyや1girlになるでしょう。 + +identifierは学習対象を識別して学習するためのものです。任意の単語で構いませんが、元論文によると「tokinizerで1トークンになる3文字以下でレアな単語」が良いとのことです。 + +identifierとclassを使い、たとえば「shs dog」などでモデルを学習することで、学習させたい対象をclassから識別して学習できます。 + +画像生成時には「shs dog」とすれば学ばせた犬種の画像が生成されます。 + +(identifierとして私が最近使っているものを参考までに挙げると、``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny`` などです。) + +## step 3. 学習用画像の準備 +学習用画像を格納するフォルダを作成します。 __さらにその中に__ 、以下の名前でディレクトリを作成します。 + +``` +<繰り返し回数>_ +``` + +間の``_``を忘れないでください。 + +繰り返し回数は、正則化画像と枚数を合わせるために指定します(後述します)。 + +たとえば「sls frog」というプロンプトで、データを20回繰り返す場合、「20_sls frog」となります。以下のようになります。 + +![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) + +## step 4. 正則化画像の準備 +正則化画像を使う場合の手順です。使わずに学習することもできます(正則化画像を使わないと区別ができなくなるので対象class全体が影響を受けます)。 + +正則化画像を格納するフォルダを作成します。 __さらにその中に__ ``<繰り返し回数>_`` という名前でディレクトリを作成します。 + +たとえば「frog」というプロンプトで、データを繰り返さない(1回だけ)場合、以下のようになります。 + +![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) + +繰り返し回数は「 __学習用画像の繰り返し回数×学習用画像の枚数≧正則化画像の繰り返し回数×正則化画像の枚数__ 」となるように指定してください。 + +(1 epochのデータ数が「学習用画像の繰り返し回数×学習用画像の枚数」となります。正則化画像の枚数がそれより多いと、余った部分の正則化画像は使用されません。) + +## step 5. 学習の実行 +スクリプトを実行します。最大限、メモリを節約したコマンドは以下のようになります(実際には1行で入力します)。 + +※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。 + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> + --train_data_dir=<学習用データのディレクトリ> + --reg_data_dir=<正則化画像のディレクトリ> + --output_dir=<学習したモデルの出力先ディレクトリ> + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=1 + --learning_rate=1e-6 + --max_train_steps=1600 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents + --gradient_checkpointing +``` + +num_cpu_threads_per_processにはCPUコア数を指定するとよいようです。 + +pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになります(save_model_asオプションで変更できます)。 + +prior_loss_weightは正則化画像のlossの重みです。通常は1.0を指定します。 + +resolutionは画像のサイズ(解像度、幅と高さ)になります。bucketing(後述)を用いない場合、学習用画像、正則化画像はこのサイズとしてください。 + +train_batch_sizeは学習時のバッチサイズです。max_train_stepsを1600とします。学習率learning_rateは、diffusers版では5e-6ですがStableDiffusion版は1e-6ですのでここでは1e-6を指定しています。 + +省メモリ化のためmixed_precision="bf16"(または"fp16")、およびgradient_checkpointing を指定します。 + +xformersオプションを指定し、xformersのCrossAttentionを用います。xformersをインストールしていない場合、エラーとなる場合(mixed_precisionなしの場合、私の環境ではエラーとなりました)、代わりにmem_eff_attnオプションを指定すると省メモリ版CrossAttentionを使用します(速度は遅くなります)。 + +省メモリ化のためcache_latentsオプションを指定してVAEの出力をキャッシュします。 + +ある程度メモリがある場合はたとえば以下のように指定します。 + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> + --train_data_dir=<学習用データのディレクトリ> + --reg_data_dir=<正則化画像のディレクトリ> + --output_dir=<学習したモデルの出力先ディレクトリ> + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=4 + --learning_rate=1e-6 + --max_train_steps=400 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents +``` + +gradient_checkpointingを外し高速化します(メモリ使用量は増えます)。バッチサイズを増やし、高速化と精度向上を図ります。 + +bucketing(後述)を利用しかつaugmentation(後述)を使う場合の例は以下のようになります。 + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> + --train_data_dir=<学習用データのディレクトリ> + --reg_data_dir=<正則化画像のディレクトリ> + --output_dir=<学習したモデルの出力先ディレクトリ> + --resolution=768,512 + --train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800 + --use_8bit_adam --xformers --mixed_precision="bf16" + --save_every_n_epochs=1 --save_state --save_precision="bf16" + --logging_dir=logs + --enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280 + --color_aug --flip_aug --gradient_checkpointing --seed 42 +``` + +### ステップ数について +省メモリ化のため、ステップ当たりの学習回数がtrain_dreambooth.pyの半分になっています(対象の画像と正則化画像を同一のバッチではなく別のバッチに分割して学習するため)。 +元のDiffusers版やXavierXiao氏のStableDiffusion版とほぼ同じ学習を行うには、ステップ数を倍にしてください。 + +(shuffle=Trueのため厳密にはデータの順番が変わってしまいますが、学習には大きな影響はないと思います。) + +## 学習したモデルで画像生成する + +学習が終わると指定したフォルダにlast.ckptという名前でcheckpointが出力されます(DiffUsers版モデルを学習した場合はlastフォルダになります)。 + +v1.4/1.5およびその他の派生モデルの場合、このモデルでAutomatic1111氏のWebUIなどで推論できます。models\Stable-diffusionフォルダに置いてください。 + +v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述された.yamlファイルが別途必要になります。v2.x baseの場合はv2-inference.yamlを、768/vの場合はv2-inference-v.yamlを、同じフォルダに置き、拡張子の前の部分をモデルと同じ名前にしてください。 + +![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) + +各yamlファイルは[https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion](Stability AIのSD2.0のリポジトリ)にあります。 + +# その他の学習オプション + +## Stable Diffusion 2.0対応 --v2 / --v_parameterization +Hugging Faceのstable-diffusion-2-baseを使う場合はv2オプションを、stable-diffusion-2または768-v-ema.ckptを使う場合はv2とv_parameterizationの両方のオプションを指定してください。 + +なおSD 2.0の学習はText Encoderが大きくなっているためVRAM 12GBでは厳しいようです。 + +Stable Diffusion 2.0では大きく以下の点が変わっています。 + +1. 使用するTokenizer +2. 使用するText Encoderおよび使用する出力層(2.0は最後から二番目の層を使う) +3. Text Encoderの出力次元数(768->1024) +4. U-Netの構造(CrossAttentionのhead数など) +5. v-parameterization(サンプリング方法が変更されているらしい) + +このうちbaseでは1~4が、baseのつかない方(768-v)では1~5が採用されています。1~4を有効にするのがv2オプション、5を有効にするのがv_parameterizationオプションです。 + +## 学習データの確認 --debug_dataset +このオプションを付けることで学習を行う前に事前にどのような画像データ、キャプションで学習されるかを確認できます。Escキーを押すと終了してコマンドラインに戻ります。 + +※Colabなど画面が存在しない環境で実行するとハングするようですのでご注意ください。 + +## Text Encoderの学習を途中から行わない --stop_text_encoder_training +stop_text_encoder_trainingオプションに数値を指定すると、そのステップ数以降はText Encoderの学習を行わずU-Netだけ学習します。場合によっては精度の向上が期待できるかもしれません。 + +(恐らくText Encoderだけ先に過学習することがあり、それを防げるのではないかと推測していますが、詳細な影響は不明です。) + +## VAEを別途読み込んで学習する --vae +vaeオプションにStable Diffusionのcheckpoint、VAEのcheckpointファイル、DiffusesのモデルまたはVAE(ともにローカルまたはHugging FaceのモデルIDが指定できます)のいずれかを指定すると、そのVAEを使って学習します(latentsのキャッシュ時または学習中のlatents取得時)。 +保存されるモデルはこのVAEを組み込んだものになります。 + +## 学習途中での保存 --save_every_n_epochs / --save_state / --resume +save_every_n_epochsオプションに数値を指定すると、そのエポックごとに学習途中のモデルを保存します。 + +save_stateオプションを同時に指定すると、optimizer等の状態も含めた学習状態を合わせて保存します(checkpointから学習再開するのに比べて、精度の向上、学習時間の短縮が期待できます)。学習状態は保存先フォルダに"epoch-??????-state"(??????はエポック数)という名前のフォルダで出力されます。長時間にわたる学習時にご利用ください。 + +保存された学習状態から学習を再開するにはresumeオプションを使います。学習状態のフォルダを指定してください。 + +なおAcceleratorの仕様により(?)、エポック数、global stepは保存されておらず、resumeしたときにも1からになりますがご容赦ください。 + +## Tokenizerのパディングをしない --no_token_padding +no_token_paddingオプションを指定するとTokenizerの出力をpaddingしません(Diffusers版の旧DreamBoothと同じ動きになります)。 + +## 任意サイズの画像での学習 --resolution +正方形以外で学習できます。resolutionに「448,640」のように「幅,高さ」で指定してください。幅と高さは64で割り切れる必要があります。学習用画像、正則化画像のサイズを合わせてください。 + +個人的には縦長の画像を生成することが多いため「448,640」などで学習することもあります。 + +## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso +enable_bucketオプションを指定すると有効になります。Stable Diffusionは512x512で学習されていますが、それに加えて256x768や384x640といった解像度でも学習します。 + +このオプションを指定した場合は、学習用画像、正則化画像を特定の解像度に統一する必要はありません。いくつかの解像度(アスペクト比)から最適なものを選び、その解像度で学習します。 +解像度は64ピクセル単位のため、元画像とアスペクト比が完全に一致しない場合がありますが、その場合は、はみ出した部分がわずかにトリミングされます。 + +解像度の最小サイズをmin_bucket_resoオプションで、最大サイズをmax_bucket_resoで指定できます。デフォルトはそれぞれ256、1024です。 +たとえば最小サイズに384を指定すると、256x1024や320x768などの解像度は使わなくなります。 +解像度を768x768のように大きくした場合、最大サイズに1280などを指定しても良いかもしれません。 + +なおAspect Ratio Bucketingを有効にするときには、正則化画像についても、学習用画像と似た傾向の様々な解像度を用意した方がいいかもしれません。 + +(ひとつのバッチ内の画像が学習用画像、正則化画像に偏らなくなるため。そこまで大きな影響はないと思いますが……。) + +## augmentation --color_aug / --flip_aug +augmentationは学習時に動的にデータを変化させることで、モデルの性能を上げる手法です。color_augで色合いを微妙に変えつつ、flip_augで左右反転をしつつ、学習します。 + +動的にデータを変化させるため、cache_latentsオプションと同時に指定できません。 + +## 保存時のデータ精度の指定 --save_precision +save_precisionオプションにfloat、fp16、bf16のいずれかを指定すると、その形式でcheckpointを保存します(Stable Diffusion形式で保存する場合のみ)。checkpointのサイズを削減したい場合などにお使いください。 + +## 任意の形式で保存する --save_model_as +モデルの保存形式を指定します。ckpt、safetensors、diffusers、diffusers_safetensorsのいずれかを指定してください。 + +Stable Diffusion形式(ckptまたはsafetensors)を読み込み、Diffusers形式で保存する場合、不足する情報はHugging Faceからv1.5またはv2.1の情報を落としてきて補完します。 + +## 学習ログの保存 --logging_dir / --log_prefix +logging_dirオプションにログ保存先フォルダを指定してください。TensorBoard形式のログが保存されます。 + +たとえば--logging_dir=logsと指定すると、作業フォルダにlogsフォルダが作成され、その中の日時フォルダにログが保存されます。 +また--log_prefixオプションを指定すると、日時の前に指定した文字列が追加されます。「--logging_dir=logs --log_prefix=db_style1_」などとして識別用にお使いください。 + +TensorBoardでログを確認するには、別のコマンドプロンプトを開き、作業フォルダで以下のように入力します(tensorboardはDiffusersのインストール時にあわせてインストールされると思いますが、もし入っていないならpip install tensorboardで入れてください)。 + +``` +tensorboard --logdir=logs +``` + +その後ブラウザを開き、http://localhost:6006/ へアクセスすると表示されます。 + +## 学習率のスケジューラ関連の指定 --lr_scheduler / --lr_warmup_steps +lr_schedulerオプションで学習率のスケジューラをlinear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmupから選べます。デフォルトはconstantです。lr_warmup_stepsでスケジューラのウォームアップ(だんだん学習率を変えていく)ステップ数を指定できます。詳細については各自お調べください。 + +## 勾配をfp16とした学習(実験的機能) --full_fp16 +full_fp16オプションを指定すると勾配を通常のfloat32からfloat16(fp16)に変更して学習します(mixed precisionではなく完全なfp16学習になるようです)。 +これによりSD1.xの512x512サイズでは8GB未満、SD2.xの512x512サイズで12GB未満のVRAM使用量で学習できるようです。 + +あらかじめaccelerate configでfp16を指定し、オプションで ``mixed_precision="fp16"`` としてください(bf16では動作しません)。 + +メモリ使用量を最小化するためには、xformers、use_8bit_adam、cache_latents、gradient_checkpointingの各オプションを指定し、train_batch_sizeを1としてください。 + +(余裕があるようならtrain_batch_sizeを段階的に増やすと若干精度が上がるはずです。) + +PyTorchのソースにパッチを当てて無理やり実現しています(PyTorch 1.12.1と1.13.0で確認)。精度はかなり落ちますし、途中で学習失敗する確率も高くなります。 +学習率やステップ数の設定もシビアなようです。それらを認識したうえで自己責任でお使いください。 + +# その他の学習方法 + +## 複数class、複数対象(identifier)の学習 +方法は単純で、学習用画像のフォルダ内に ``繰り返し回数_ `` のフォルダを複数、正則化画像フォルダにも同様に ``繰り返し回数_`` のフォルダを複数、用意してください。 + +たとえば「sls frog」と「cpc rabbit」を同時に学習する場合、以下のようになります。 + +![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png) + +classがひとつで対象が複数の場合、正則化画像フォルダはひとつで構いません。たとえば1girlにキャラAとキャラBがいる場合は次のようにします。 + +- train_girls + - 10_sls 1girl + - 10_cpc 1girl +- reg_girls + - 1_1girl + +データ数にばらつきがある場合、繰り返し回数を調整してclass、identifierごとの枚数を統一すると良い結果が得られることがあるようです。 + +## DreamBoothでキャプションを使う +学習用画像、正則化画像のフォルダに、画像と同じファイル名で、拡張子.caption(オプションで変えられます)のファイルを置くと、そのファイルからキャプションを読み込みプロンプトとして学習します。 + +※それらの画像の学習に、フォルダ名(identifier class)は使用されなくなります。 + +各画像にキャプションを付けることで(BLIP等を使っても良いでしょう)、学習したい属性をより明確にできるかもしれません。 + +キャプションファイルの拡張子はデフォルトで.captionです。--caption_extensionで変更できます。--shuffle_captionオプションで学習時のキャプションについて、カンマ区切りの各部分をシャッフルしながら学習します。 + diff --git a/train_db_README.md b/train_db_README.md new file mode 100644 index 0000000..2367d29 --- /dev/null +++ b/train_db_README.md @@ -0,0 +1,295 @@ +A guide to DreamBooth. The same procedure is used for training additional networks such as LoRA. + +# overview + +The main functions of the script are as follows. + +- Memory saving by 8bit Adam optimizer and latent cache (similar to ShivamShirao's version). +- Saved memory by xformers. +- Study in any size, not just 512x512. +- Quality improvement with augmentation. +- Supports fine tuning of Text Encoder+U-Net as well as DreamBooth. +- Read and write models in StableDiffusion format. +- Aspect Ratio Bucketing. +- Supports Stable Diffusion v2.0. + +# learning procedure + +## step 1. Environment improvement + +See the README in this repository. + + +## step 2. Determine identifier and class + +Decide the word identifier that connects the target you want to learn and the class to which the target belongs. + +(There are various names such as instance, but for the time being I will stick to the original paper.) + +Here's a very brief explanation (look it up for more details). + +class is the general type to learn. For example, if you want to learn a specific breed of dog, the class will be dog. Anime characters will be boy, girl, 1boy or 1girl depending on the model. + +The identifier is for identifying and learning the learning target. Any word is fine, but according to the original paper, ``a rare word with 3 letters or less that becomes one token with tokinizer'' is good. + +By using the identifier and class to train the model, for example, "shs dog", you can learn by identifying the object you want to learn from the class. + +When generating an image, if you say "shs dog", an image of the learned dog breed will be generated. + +(For reference, the identifier I use these days is ``shs sts scs cpc coc cic msm usu ici lvl cic dii muk ori hru rik koo yos wny``.) + +## step 3. Prepare images for training +Create a folder to store training images. __In addition, create a directory with the following name: + +``` +_ +``` + +Don't forget the ``_`` between them. + +The number of repetitions is specified to match the number of regularized images (described later). + +For example, at the prompt "sls frog", to repeat the data 20 times, it would be "20_sls frog". It will be as follows. + +![image](https://user-images.githubusercontent.com/52813779/210770636-1c851377-5936-4c15-90b7-8ac8ad6c2074.png) + +## step 4. Preparing regularized images +This is the procedure when using a regularized image. It is also possible to learn without using the regularization image (the whole target class is affected because it is impossible to distinguish without using the regularization image). + +Create a folder to store the regularized images. __In addition, __ create a directory named ``_``. + +For example, with the prompt "frog" and without repeating the data (just once): + +![image](https://user-images.githubusercontent.com/52813779/210770897-329758e5-3675-49f1-b345-c135f1725832.png) + +Specify the number of iterations so that " __ number of iterations of training images x number of training images ≥ number of iterations of regularization images x number of regularization images __". + +(The number of data in one epoch is "number of repetitions of training images x number of training images". If the number of regularization images is more than that, the remaining regularization images will not be used.) + +## step 5. Run training +Run the script. The maximally memory-saving command looks like this (actually typed on one line): + +*The command for learning additional networks such as LoRA is ``train_network.py`` instead of ``train_db.py``. You will also need additional network_\* options, so please refer to LoRA's guide. + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path= + --train_data_dir= + --reg_data_dir= + --output_dir= + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=1 + --learning_rate=1e-6 + --max_train_steps=1600 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents + --gradient_checkpointing +``` + +It seems to be good to specify the number of CPU cores for num_cpu_threads_per_process. + +Specify the model to perform additional training in pretrained_model_name_or_path. You can specify a Stable Diffusion checkpoint file (.ckpt or .safetensors), a model directory on the Diffusers local disk, or a Diffusers model ID (such as "stabilityai/stable-diffusion-2"). The saved model after training will be saved in the same format as the original model by default (can be changed with the save_model_as option). + +prior_loss_weight is the loss weight of the regularized image. Normally, specify 1.0. + +resolution will be the size of the image (resolution, width and height). If bucketing (described later) is not used, use this size for training images and regularization images. + +train_batch_size is the training batch size. Set max_train_steps to 1600. The learning rate learning_rate is 5e-6 in the diffusers version and 1e-6 in the StableDiffusion version, so 1e-6 is specified here. + +Specify mixed_precision="bf16" (or "fp16") and gradient_checkpointing for memory saving. + +Specify the xformers option and use xformers' CrossAttention. If you don't have xformers installed, if you get an error (without mixed_precision, it was an error in my environment), specify the mem_eff_attn option instead to use the memory-saving version of CrossAttention (speed will be slower) . + +Cache VAE output with cache_latents option to save memory. + +If you have a certain amount of memory, specify it as follows, for example. + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path= + --train_data_dir= + --reg_data_dir= + --output_dir= + --prior_loss_weight=1.0 + --resolution=512 + --train_batch_size=4 + --learning_rate=1e-6 + --max_train_steps=400 + --use_8bit_adam + --xformers + --mixed_precision="bf16" + --cache_latents +``` + +Remove gradient_checkpointing to speed up (memory usage will increase). Increase the batch size to improve speed and accuracy. + +An example of using bucketing (see below) and using augmentation (see below) looks like this: + +``` +accelerate launch --num_cpu_threads_per_process 8 train_db.py + --pretrained_model_name_or_path= + --train_data_dir= + --reg_data_dir= + --output_dir= + --resolution=768,512 + --train_batch_size=20 --learning_rate=5e-6 --max_train_steps=800 + --use_8bit_adam --xformers --mixed_precision="bf16" + --save_every_n_epochs=1 --save_state --save_precision="bf16" + --logging_dir=logs + --enable_bucket --min_bucket_reso=384 --max_bucket_reso=1280 + --color_aug --flip_aug --gradient_checkpointing --seed 42 +``` + +### About the number of steps +To save memory, the number of training steps per step is half that of train_drebooth.py (because the target image and the regularization image are divided into different batches instead of the same batch). +Double the number of steps to get almost the same training as the original Diffusers version and XavierXiao's StableDiffusion version. + +(Strictly speaking, the order of the data changes due to shuffle=True, but I don't think it has a big impact on learning.) + +## Generate an image with the trained model + +Name last.ckpt in the specified folder when learning is completed will output the checkpoint (if you learned the DiffUsers version model, it will be the last folder). + +For v1.4/1.5 and other derived models, this model can be inferred by Automatic1111's WebUI, etc. Place it in the models\Stable-diffusion folder. + +When generating images with WebUI with the v2.x model, a separate .yaml file that describes the model specifications is required. Place v2-inference.yaml for v2.x base and v2-inference-v.yaml for 768/v in the same folder and make the part before the extension the same name as the model. + +![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) + +Each yaml file can be found at [https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion] (Stability AI SD2.0 repository). + +# Other study options + +## Supports Stable Diffusion 2.0 --v2 / --v_parameterization +Specify the v2 option when using Hugging Face's stable-diffusion-2-base, and specify both the v2 and v_parameterization options when using stable-diffusion-2 or 768-v-ema.ckpt. + +In addition, learning SD 2.0 seems to be difficult with VRAM 12GB because the Text Encoder is getting bigger. + +The following points have changed significantly in Stable Diffusion 2.0. + +1. Tokenizer to use +2. Which Text Encoder to use and which output layer to use (2.0 uses the penultimate layer) +3. Output dimensionality of Text Encoder (768->1024) +4. Structure of U-Net (number of heads of CrossAttention, etc.) +5. v-parameterization (the sampling method seems to have changed) + +Among these, 1 to 4 are adopted for base, and 1 to 5 are adopted for the one without base (768-v). Enabling 1-4 is the v2 option, and enabling 5 is the v_parameterization option. + +## check training data --debug_dataset +By adding this option, you can check what kind of image data and captions will be learned in advance before learning. Press Esc to exit and return to the command line. + +*Please note that it seems to hang when executed in an environment where there is no screen such as Colab. + +## Stop training Text Encoder --stop_text_encoder_training +If you specify a numerical value for the stop_text_encoder_training option, after that number of steps, only the U-Net will be trained without training the Text Encoder. In some cases, the accuracy may be improved. + +(Probably only the Text Encoder may overfit first, and I guess that it can be prevented, but the detailed impact is unknown.) + +## Load and learn VAE separately --vae +If you specify either a Stable Diffusion checkpoint, a VAE checkpoint file, a Diffuses model, or a VAE (both of which can specify a local or Hugging Face model ID) in the vae option, that VAE is used for learning (latents when caching or getting latents during learning). +The saved model will incorporate this VAE. + +## save during learning --save_every_n_epochs / --save_state / --resume +Specifying a number for the save_every_n_epochs option saves the model during training every epoch. + +If you specify the save_state option at the same time, the learning state including the state of the optimizer etc. will be saved together (compared to restarting learning from the checkpoint, you can expect to improve accuracy and shorten the learning time). The learning state is output in a folder named "epoch-??????-state" (?????? is the number of epochs) in the destination folder. Please use it when studying for a long time. + +Use the resume option to resume training from a saved training state. Please specify the learning state folder. + +Please note that due to the specifications of Accelerator (?), the number of epochs and global step are not saved, and it will start from 1 even when you resume. + +## No tokenizer padding --no_token_padding +The no_token_padding option does not pad the output of the Tokenizer (same behavior as Diffusers version of old DreamBooth). + +## Training with arbitrary size images --resolution +You can study outside the square. Please specify "width, height" like "448,640" in resolution. Width and height must be divisible by 64. Match the size of the training image and the regularization image. + +Personally, I often generate vertically long images, so I sometimes learn with "448, 640". + +## Aspect Ratio Bucketing --enable_bucket / --min_bucket_reso / --max_bucket_reso +It is enabled by specifying the enable_bucket option. Stable Diffusion is trained at 512x512, but also at resolutions such as 256x768 and 384x640. + +If you specify this option, you do not need to unify the training images and regularization images to a specific resolution. Choose from several resolutions (aspect ratios) and learn at that resolution. +Since the resolution is 64 pixels, the aspect ratio may not be exactly the same as the original image. + +You can specify the minimum size of the resolution with the min_bucket_reso option and the maximum size with the max_bucket_reso. The defaults are 256 and 1024 respectively. +For example, specifying a minimum size of 384 will not use resolutions such as 256x1024 or 320x768. +If you increase the resolution to 768x768, you may want to specify 1280 as the maximum size. + +When Aspect Ratio Bucketing is enabled, it may be better to prepare regularization images with various resolutions that are similar to the training images. + +(Because the images in one batch are not biased toward training images and regularization images. + +## augmentation --color_aug / --flip_aug +Augmentation is a method of improving model performance by dynamically changing data during learning. Learn while subtly changing the hue with color_aug and flipping left and right with flip_aug. + +Since the data changes dynamically, it cannot be specified together with the cache_latents option. + +## Specify data precision when saving --save_precision +Specifying float, fp16, or bf16 as the save_precision option will save the checkpoint in that format (only when saving in Stable Diffusion format). Please use it when you want to reduce the size of checkpoint. + +## save in any format --save_model_as +Specify the save format of the model. Specify one of ckpt, safetensors, diffusers, diffusers_safetensors. + +When reading Stable Diffusion format (ckpt or safetensors) and saving in Diffusers format, missing information is supplemented by dropping v1.5 or v2.1 information from Hugging Face. + +## Save learning log --logging_dir / --log_prefix +Specify the log save destination folder in the logging_dir option. Logs in TensorBoard format are saved. + +For example, if you specify --logging_dir=logs, a logs folder will be created in your working folder, and logs will be saved in the date/time folder. +Also, if you specify the --log_prefix option, the specified string will be added before the date and time. Use "--logging_dir=logs --log_prefix=db_style1_" for identification. + +To check the log with TensorBoard, open another command prompt and enter the following in the working folder (I think tensorboard is installed when Diffusers is installed, but if it is not installed, pip install Please put it in tensorboard). + +``` +tensorboard --logdir=logs +``` + +Then open your browser and go to http://localhost:6006/ to see it. + +## scheduler related specification of learning rate --lr_scheduler / --lr_warmup_steps +You can choose the learning rate scheduler from linear, cosine, cosine_with_restarts, polynomial, constant, constant_with_warmup with the lr_scheduler option. Default is constant. With lr_warmup_steps, you can specify the number of steps to warm up the scheduler (gradually changing the learning rate). Please do your own research for details. + +## Training with fp16 gradient (experimental feature) --full_fp16 +The full_fp16 option will change the gradient from normal float32 to float16 (fp16) and learn (it seems to be full fp16 learning instead of mixed precision). +As a result, it seems that the SD1.x 512x512 size can be learned with a VRAM usage of less than 8GB, and the SD2.x 512x512 size can be learned with a VRAM usage of less than 12GB. + +Specify fp16 in the accelerate config beforehand and optionally set ``mixed_precision="fp16"`` (bf16 does not work). + +To minimize memory usage, use xformers, use_8bit_adam, cache_latents, gradient_checkpointing options and set train_batch_size to 1. + +(If you can afford it, increasing the train_batch_size step by step should improve the accuracy a little.) + +It is realized by patching the PyTorch source (confirmed with PyTorch 1.12.1 and 1.13.0). Accuracy will drop considerably, and the probability of learning failure on the way will also increase. +The setting of the learning rate and the number of steps seems to be severe. Please be aware of them and use them at your own risk. + +# Other learning methods + +## Learning multiple classes, multiple identifiers +The method is simple, multiple folders with ``Repetition count_ `` in the training image folder, and a folder with ``Repetition count_`` in the regularization image folder. Please prepare multiple + +For example, learning "sls frog" and "cpc rabbit" at the same time would look like this: + +![image](https://user-images.githubusercontent.com/52813779/210777933-a22229db-b219-4cd8-83ca-e87320fc4192.png) + +If you have one class and multiple targets, you can have only one regularized image folder. For example, if 1girl has character A and character B, do as follows. + +- train_girls + - 10_sls 1girl + - 10_cpc 1girl +- reg_girls + -1_1girl + +If the number of data varies, it seems that good results can be obtained by adjusting the number of repetitions to unify the number of sheets for each class and identifier. + +## Use captions in DreamBooth +If you put a file with the same file name as the image and the extension .caption (you can change it in the option) in the training image and regularization image folders, the caption will be read from that file and learned as a prompt. + +* The folder name (identifier class) will no longer be used for training those images. + +Adding captions to each image (you can use BLIP, etc.) may help clarify the attributes you want to learn. + +Caption files have a .caption extension by default. You can change it with --caption_extension. With the --shuffle_caption option, study captions during learning while shuffling each part separated by commas. \ No newline at end of file diff --git a/train_network.py b/train_network.py index cf171a7..9f292b9 100644 --- a/train_network.py +++ b/train_network.py @@ -1,891 +1,17 @@ -import gc import importlib -import json -import time -from typing import NamedTuple -from torch.autograd.function import Function import argparse -import glob +import gc import math import os -import random from tqdm import tqdm import torch -from torchvision import transforms -from accelerate import Accelerator from accelerate.utils import set_seed -from transformers import CLIPTokenizer import diffusers -from diffusers import DDPMScheduler, StableDiffusionPipeline -import albumentations as albu -import numpy as np -from PIL import Image -import cv2 -from einops import rearrange -from torch import einsum +from diffusers import DDPMScheduler -import library.model_util as model_util - -# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = "openai/clip-vit-large-patch14" -V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ - -# checkpointファイル名 -EPOCH_STATE_NAME = "epoch-{:06d}-state" -LAST_STATE_NAME = "last-state" - -EPOCH_FILE_NAME = "epoch-{:06d}" -LAST_FILE_NAME = "last" - - -# region dataset - -class ImageInfo(): - def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: - self.image_key: str = image_key - self.num_repeats: int = num_repeats - self.caption: str = caption - self.is_reg: bool = is_reg - self.absolute_path: str = absolute_path - self.image_size: tuple[int, int] = None - self.bucket_reso: tuple[int, int] = None - self.latents: torch.Tensor = None - self.latents_flipped: torch.Tensor = None - self.latents_npz: str = None - self.latents_npz_flipped: str = None - - -class BucketBatchIndex(NamedTuple): - bucket_index: int - batch_index: int - - -class BaseDataset(torch.utils.data.Dataset): - def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug: bool, color_aug: bool, face_crop_aug_range, debug_dataset: bool) -> None: - super().__init__() - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - self.width, self.height = resolution - self.face_crop_aug_range = face_crop_aug_range - self.flip_aug = flip_aug - self.color_aug = color_aug - self.debug_dataset = debug_dataset - - self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - - # augmentation - flip_p = 0.5 if flip_aug else 0.0 - if color_aug: - # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hueあたりを触る - self.aug = albu.Compose([ - albu.OneOf([ - albu.HueSaturationValue(8, 0, 0, p=.5), - albu.RandomGamma((95, 105), p=.5), - ], p=.33), - albu.HorizontalFlip(p=flip_p) - ], p=1.) - elif flip_aug: - self.aug = albu.Compose([ - albu.HorizontalFlip(p=flip_p) - ], p=1.) - else: - self.aug = None - - self.image_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) - - self.image_data: dict[str, ImageInfo] = {} - - def process_caption(self, caption): - if self.shuffle_caption: - tokens = caption.strip().split(",") - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[:self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens:] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ",".join(tokens).strip() - return caption - - def get_input_ids(self, caption): - input_ids = self.tokenizer(caption, padding="max_length", truncation=True, - max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) - ids_chunk = (input_ids[0].unsqueeze(0), - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): - ids_chunk = (input_ids[0].unsqueeze(0), # BOS - input_ids[i:i + self.tokenizer.model_max_length - 2], - input_ids[-1].unsqueeze(0)) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) - - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id - - iids_list.append(ids_chunk) - - input_ids = torch.stack(iids_list) # 3,77 - return input_ids - - def register_image(self, info: ImageInfo): - self.image_data[info.image_key] = info - - def make_buckets(self, enable_bucket, min_size, max_size): - ''' - bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る) - min_size and max_size are ignored when enable_bucket is False - ''' - - self.enable_bucket = enable_bucket - - print("loading image sizes.") - for info in tqdm(self.image_data.values()): - if info.image_size is None: - info.image_size = self.get_image_size(info.absolute_path) - - if enable_bucket: - print("make buckets") - else: - print("prepare dataset") - - # bucketingを用意する - if enable_bucket: - bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size) - else: - # bucketはひとつだけ、すべての画像は同じ解像度 - bucket_resos = [(self.width, self.height)] - bucket_aspect_ratios = [self.width / self.height] - bucket_aspect_ratios = np.array(bucket_aspect_ratios) - - # bucketを作成する - if enable_bucket: - img_ar_errors = [] - for image_info in self.image_data.values(): - # bucketを決める - image_width, image_height = image_info.image_size - aspect_ratio = image_width / image_height - ar_errors = bucket_aspect_ratios - aspect_ratio - - bucket_id = np.abs(ar_errors).argmin() - image_info.bucket_reso = bucket_resos[bucket_id] - - ar_error = ar_errors[bucket_id] - img_ar_errors.append(ar_error) - else: - reso = (self.width, self.height) - for image_info in self.image_data.values(): - image_info.bucket_reso = reso - - # 画像をbucketに分割する - self.buckets: list[str] = [[] for _ in range(len(bucket_resos))] - reso_to_index = {} - for i, reso in enumerate(bucket_resos): - reso_to_index[reso] = i - - for image_info in self.image_data.values(): - bucket_index = reso_to_index[image_info.bucket_reso] - for _ in range(image_info.num_repeats): - self.buckets[bucket_index].append(image_info.image_key) - - if enable_bucket: - print("number of images (including repeats for DreamBooth) / 各bucketの画像枚数(DreamBoothの場合は繰り返し回数を含む)") - for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): - print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") - img_ar_errors = np.array(img_ar_errors) - print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") - - # 参照用indexを作る - self.buckets_indices: list(BucketBatchIndex) = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append(BucketBatchIndex(bucket_index, batch_index)) - - self.shuffle_buckets() - self._length = len(self.buckets_indices) - - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) - - def load_image(self, image_path): - image = Image.open(image_path) - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img - - def resize_and_trim(self, image, reso): - image_height, image_width = image.shape[0:2] - ar_img = image_width / image_height - ar_reso = reso[0] / reso[1] - if ar_img > ar_reso: # 横が長い→縦を合わせる - scale = reso[1] / image_height - else: - scale = reso[0] / image_width - resized_size = (int(image_width * scale + .5), int(image_height * scale + .5)) - - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ - if resized_size[0] > reso[0]: - trim_size = resized_size[0] - reso[0] - image = image[:, trim_size//2:trim_size//2 + reso[0]] - elif resized_size[1] > reso[1]: - trim_size = resized_size[1] - reso[1] - image = image[trim_size//2:trim_size//2 + reso[1]] - assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \ - f"internal error, illegal trimmed size: {image.shape}, {reso}" - return image - - def cache_latents(self, vae): - print("caching latents.") - for info in tqdm(self.image_data.values()): - if info.latents_npz is not None: - info.latents = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - info.latents_flipped = self.load_latents_from_npz(info, True) - info.latents_flipped = torch.FloatTensor(info.latents_flipped) - continue - - image = self.load_image(info.absolute_path) - image = self.resize_and_trim(image, info.bucket_reso) - - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - if self.flip_aug: - image = image[:, ::-1].copy() # cannot convert to Tensor without copy - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") - - def get_image_size(self, image_path): - image = Image.open(image_path) - return image.size - - def load_image_with_face_info(self, image_path: str): - img = self.load_image(image_path) - - face_cx = face_cy = face_w = face_h = 0 - if self.face_crop_aug_range is not None: - tokens = os.path.splitext(os.path.basename(image_path))[0].split('_') - if len(tokens) >= 5: - face_cx = int(tokens[-4]) - face_cy = int(tokens[-3]) - face_w = int(tokens[-2]) - face_h = int(tokens[-1]) - - return img, face_cx, face_cy, face_w, face_h - - # いい感じに切り出す - def crop_target(self, image, face_cx, face_cy, face_w, face_h): - height, width = image.shape[0:2] - if height == self.height and width == self.width: - return image - - # 画像サイズはsizeより大きいのでリサイズする - face_size = max(face_w, face_h) - min_scale = max(self.height / height, self.width / width) # 画像がモデル入力サイズぴったりになる倍率(最小の倍率) - min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # 指定した顔最小サイズ - max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # 指定した顔最大サイズ - if min_scale >= max_scale: # range指定がmin==max - scale = min_scale - else: - scale = random.uniform(min_scale, max_scale) - - nh = int(height * scale + .5) - nw = int(width * scale + .5) - assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}" - image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) - face_cx = int(face_cx * scale + .5) - face_cy = int(face_cy * scale + .5) - height, width = nh, nw - - # 顔を中心として448*640とかへ切り出す - for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))): - p1 = face_p - target_size // 2 # 顔を中心に持ってくるための切り出し位置 - - if self.random_crop: - # 背景も含めるために顔を中心に置く確率を高めつつずらす - range = max(length - face_p, face_p) # 画像の端から顔中心までの距離の長いほう - p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # -range ~ +range までのいい感じの乱数 - else: - # range指定があるときのみ、すこしだけランダムに(わりと適当) - if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]: - if face_size > self.size // 10 and face_size >= 40: - p1 = p1 + random.randint(-face_size // 20, +face_size // 20) - - p1 = max(0, min(p1, length - target_size)) - - if axis == 0: - image = image[p1:p1 + target_size, :] - else: - image = image[:, p1:p1 + target_size] - - return image - - def load_latents_from_npz(self, image_info: ImageInfo, flipped): - npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz - return np.load(npz_file)['arr_0'] - - def __len__(self): - return self._length - - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() - - bucket = self.buckets[self.buckets_indices[index].bucket_index] - image_index = self.buckets_indices[index].batch_index * self.batch_size - - loss_weights = [] - captions = [] - input_ids_list = [] - latents_list = [] - images = [] - - for image_key in bucket[image_index:image_index + self.batch_size]: - image_info = self.image_data[image_key] - loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) - - # image/latentsを処理する - if image_info.latents is not None: - latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped - image = None - elif image_info.latents_npz is not None: - latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5) - latents = torch.FloatTensor(latents) - image = None - else: - # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path) - im_h, im_w = img.shape[0:2] - - if self.enable_bucket: - img = self.resize_and_trim(img, image_info.bucket_reso) - else: - if face_cx > 0: # 顔位置情報あり - img = self.crop_target(img, face_cx, face_cy, face_w, face_h) - elif im_h > self.height or im_w > self.width: - assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}" - if im_h > self.height: - p = random.randint(0, im_h - self.height) - img = img[p:p + self.height] - if im_w > self.width: - p = random.randint(0, im_w - self.width) - img = img[:, p:p + self.width] - - im_h, im_w = img.shape[0:2] - assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" - - # augmentation - if self.aug is not None: - img = self.aug(image=img)['image'] - - latents = None - image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる - - images.append(image) - latents_list.append(latents) - - caption = self.process_caption(image_info.caption) - captions.append(caption) - input_ids_list.append(self.get_input_ids(caption)) - - example = {} - example['loss_weights'] = torch.FloatTensor(loss_weights) - example['input_ids'] = torch.stack(input_ids_list) - - if images[0] is not None: - images = torch.stack(images) - images = images.to(memory_format=torch.contiguous_format).float() - else: - images = None - example['images'] = images - - example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None - - if self.debug_dataset: - example['image_keys'] = bucket[image_index:image_index + self.batch_size] - example['captions'] = captions - return example - - -class DreamBoothDataset(BaseDataset): - def __init__(self, batch_size, train_data_dir, reg_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption, shuffle_keep_tokens, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) - - self.batch_size = batch_size - self.size = min(self.width, self.height) # 短いほう - self.prior_loss_weight = prior_loss_weight - self.random_crop = random_crop - self.latents_cache = None - self.enable_bucket = False - - def read_caption(img_path): - # captionの候補ファイル名を作る - base_name = os.path.splitext(img_path)[0] - base_name_face_det = base_name - tokens = base_name.split("_") - if len(tokens) >= 5: - base_name_face_det = "_".join(tokens[:-4]) - cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension] - - caption = None - for cap_path in cap_paths: - if os.path.isfile(cap_path): - with open(cap_path, "rt", encoding='utf-8') as f: - lines = f.readlines() - assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}" - caption = lines[0].strip() - break - return caption - - def load_dreambooth_dir(dir): - if not os.path.isdir(dir): - # print(f"ignore file: {dir}") - return 0, [], [] - - tokens = os.path.basename(dir).split('_') - try: - n_repeats = int(tokens[0]) - except ValueError as e: - print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}") - return 0, [], [] - - caption_by_folder = '_'.join(tokens[1:]) - img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \ - glob.glob(os.path.join(dir, "*.webp")) - print(f"found directory {n_repeats}_{caption_by_folder} contains {len(img_paths)} image files") - - # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う - captions = [] - for img_path in img_paths: - cap_for_img = read_caption(img_path) - captions.append(caption_by_folder if cap_for_img is None else cap_for_img) - - return n_repeats, img_paths, captions - - print("prepare train images.") - train_dirs = os.listdir(train_data_dir) - num_train_images = 0 - for dir in train_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(train_data_dir, dir)) - num_train_images += n_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, False, img_path) - self.register_image(info) - print(f"{num_train_images} train images with repeating.") - self.num_train_images = num_train_images - - # reg imageは数を数えて学習画像と同じ枚数にする - num_reg_images = 0 - if reg_data_dir: - print("prepare reg images.") - reg_infos: list[ImageInfo] = [] - - reg_dirs = os.listdir(reg_data_dir) - for dir in reg_dirs: - n_repeats, img_paths, captions = load_dreambooth_dir(os.path.join(reg_data_dir, dir)) - num_reg_images += n_repeats * len(img_paths) - for img_path, caption in zip(img_paths, captions): - info = ImageInfo(img_path, n_repeats, caption, True, img_path) - reg_infos.append(info) - - print(f"{num_reg_images} reg images.") - if num_train_images < num_reg_images: - print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります") - - if num_reg_images == 0: - print("no regularization images / 正則化画像が見つかりませんでした") - else: - n = 0 - while n < num_train_images: - for info in reg_infos: - self.register_image(info) - n += info.num_repeats - if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない - break - - self.num_reg_images = num_reg_images - - -class FineTuningDataset(BaseDataset): - def __init__(self, metadata, batch_size, train_data_dir, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, resolution, flip_aug, color_aug, face_crop_aug_range, dataset_repeats, debug_dataset) -> None: - super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, - resolution, flip_aug, color_aug, face_crop_aug_range, debug_dataset) - - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size - - for image_key, img_md in metadata.items(): - # path情報を作る - if os.path.exists(image_key): - abs_path = image_key - else: - # わりといい加減だがいい方法が思いつかん - abs_path = (glob.glob(os.path.join(train_data_dir, f"{image_key}.png")) + glob.glob(os.path.join(train_data_dir, f"{image_key}.jpg")) + - glob.glob(os.path.join(train_data_dir, f"{image_key}.webp"))) - assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}" - abs_path = abs_path[0] - - caption = img_md.get('caption') - tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}" - - image_info = ImageInfo(image_key, dataset_repeats, caption, False, abs_path) - image_info.image_size = img_md.get('train_resolution') - - if not self.color_aug: - # if npz exists, use them - image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(image_key) - - self.register_image(image_info) - self.num_train_images = len(metadata) * dataset_repeats - self.num_reg_images = 0 - - # check existence of all npz files - if not self.color_aug: - npz_any = False - npz_all = True - for image_info in self.image_data.values(): - has_npz = image_info.latents_npz is not None - npz_any = npz_any or has_npz - - if self.flip_aug: - has_npz = has_npz and image_info.latents_npz_flipped is not None - npz_all = npz_all and has_npz - - if npz_any and not npz_all: - break - - if not npz_any: - print(f"npz file does not exist. make latents with VAE / npzファイルが見つからないためVAEを使ってlatentsを取得します") - elif not npz_all: - print(f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します") - for image_info in self.image_data.values(): - image_info.latents_npz = image_info.latents_npz_flipped = None - - # check min/max bucket size - sizes = set() - for image_info in self.image_data.values(): - if image_info.image_size is None: - sizes = None # not calculated - break - sizes.add(image_info.image_size[0]) - sizes.add(image_info.image_size[1]) - - if sizes is None: - self.min_bucket_reso = self.max_bucket_reso = None # set as not calculated - else: - self.min_bucket_reso = min(sizes) - self.max_bucket_reso = max(sizes) - - def image_key_to_npz_file(self, image_key): - base_name = os.path.splitext(image_key)[0] - npz_file_norm = base_name + '.npz' - - if os.path.exists(npz_file_norm): - # image_key is full path - npz_file_flip = base_name + '_flip.npz' - if not os.path.exists(npz_file_flip): - npz_file_flip = None - return npz_file_norm, npz_file_flip - - # image_key is relative path - npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') - - if not os.path.exists(npz_file_norm): - npz_file_norm = None - npz_file_flip = None - elif not os.path.exists(npz_file_flip): - npz_file_flip = None - - return npz_file_norm, npz_file_flip - -# endregion - - -# region モジュール入れ替え部 -""" -高速化のためのモジュール入れ替え -""" - -# FlashAttentionを使うCrossAttention -# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py -# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE - -# constants - -EPSILON = 1e-6 - -# helper functions - - -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - -# flash attention forwards and backwards - -# https://arxiv.org/abs/2205.14135 - - -class FlashAttentionFunction(Function): - @ staticmethod - @ torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """ Algorithm 2 in the paper """ - - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - o = torch.zeros_like(q) - all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) - all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - - scale = (q.shape[-1] ** -0.5) - - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) - - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) - - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.) - - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - - exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc) - - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - - new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) - - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - - return o - - @ staticmethod - @ torch.no_grad() - def backward(ctx, do): - """ Algorithm 4 in the paper """ - - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors - - device = q.device - - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) - - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2) - ) - - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff - - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) - - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size - - attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - - if causal and q_start_index < (k_start_index + k_bucket_size - 1): - causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, - device=device).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) - - exp_attn_weights = torch.exp(attn_weights - mc) - - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.) - - p = exp_attn_weights / lc - - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) - - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) - - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) - - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) - - return dq, dk, dv, None, None, None, None - - -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() - - -def replace_unet_cross_attn_to_memory_efficient(): - print("Replace CrossAttention.forward to use FlashAttention") - flash_func = FlashAttentionFunction - - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 - - h = self.heads - q = self.to_q(x) - - context = context if context is not None else x - context = context.to(x.dtype) - k = self.to_k(context) - v = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - - out = rearrange(out, 'b h n d -> b n (h d)') - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_flash_attn - - -def replace_unet_cross_attn_to_xformers(): - print("Replace CrossAttention.forward to use xformers") - try: - import xformers.ops - except ImportError: - raise ImportError("No xformers / xformersがインストールされていないようです") - - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) - - context = default(context, x) - context = context.to(x.dtype) - - k_in = self.to_k(context) - v_in = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format - # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format - del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - # out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers -# endregion +import library.train_util as train_util +from library.train_util import DreamBoothDataset, FineTuningDataset def collate_fn(examples): @@ -893,177 +19,52 @@ def collate_fn(examples): def train(args): + train_util.verify_training_args(args) + train_util.prepare_dataset_args(args, True) + cache_latents = args.cache_latents - - # latentsをキャッシュする場合のオプション設定を確認する - if cache_latents: - assert not args.color_aug, "when caching latents, color_aug cannot be used / latentをキャッシュするときはcolor_augは使えません" - - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") - if args.v2 and args.clip_skip is not None: - print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - use_dreambooth_method = args.in_json is None - # モデル形式のオプション設定を確認する: - load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) - - # 乱数系列を初期化する if args.seed is not None: set_seed(args.seed) - # tokenizerを読み込む - print("prepare tokenizer") - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") - else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) - - if args.max_token_length is not None: - print(f"update token length: {args.max_token_length}") - - # 学習データを用意する - assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" - resolution = tuple([int(r) for r in args.resolution.split(',')]) - if len(resolution) == 1: - resolution = (resolution[0], resolution[0]) - assert len(resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" - - if args.face_crop_aug_range is not None: - face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) - assert len( - face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}" - else: - face_crop_aug_range = None + tokenizer = train_util.load_tokenizer(args) # データセットを準備する if use_dreambooth_method: print("Use DreamBooth method.") train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir, tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens, - resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.debug_dataset) + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) else: print("Train with captions.") - - # メタデータを読み込む - if os.path.exists(args.in_json): - print(f"loading existing metadata: {args.in_json}") - with open(args.in_json, "rt", encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f"no metadata / メタデータファイルがありません: {args.in_json}") - return - - if args.color_aug: - print(f"latents in npz is ignored when color_aug is True / color_augを有効にした場合、npzファイルのlatentsは無視されます") - - train_dataset = FineTuningDataset(metadata, args.train_batch_size, args.train_data_dir, + train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir, tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, - resolution, args.flip_aug, args.color_aug, face_crop_aug_range, args.dataset_repeats, args.debug_dataset) - - if train_dataset.min_bucket_reso is not None and (args.enable_bucket or train_dataset.min_bucket_reso != train_dataset.max_bucket_reso): - print(f"using bucket info in metadata / メタデータ内のbucket情報を使います") - args.min_bucket_reso = train_dataset.min_bucket_reso - args.max_bucket_reso = train_dataset.max_bucket_reso - args.enable_bucket = True - print(f"min bucket reso: {args.min_bucket_reso}, max bucket reso: {args.max_bucket_reso}") - - if args.enable_bucket: - assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" - - train_dataset.make_buckets(args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso) + args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, + args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, + args.dataset_repeats, args.debug_dataset) + train_dataset.make_buckets() if args.debug_dataset: - print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") - print("Escape for exit. / Escキーで中断、終了します") - k = 0 - for example in train_dataset: - if example['latents'] is not None: - print("sample has latents from npz file") - for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])): - print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}') - if example['images'] is not None: - im = example['images'][j] - im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8) - im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c - im = im[:, :, ::-1] # RGB -> BGR (OpenCV) - cv2.imshow("img", im) - k = cv2.waitKey() - cv2.destroyAllWindows() - if k == 27: - break - if k == 27 or example['images'] is None: - break + train_util.debug_dataset(train_dataset) return - if len(train_dataset) == 0: print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") return # acceleratorを準備する print("prepare accelerator") - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = "tensorboard" - log_prefix = "" if args.log_prefix is None else args.log_prefix - logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) - - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, - log_with=log_with, logging_dir=logging_dir) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True - try: - accelerator.unwrap_model("dummy", True) - print("Using accelerator 0.15.0 or above.") - except TypeError: - accelerator_0_15 = False - - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + accelerator, unwrap_model = train_util.prepare_accelerator(args) # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif args.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - save_dtype = None - if args.save_precision == "fp16": - save_dtype = torch.float16 - elif args.save_precision == "bf16": - save_dtype = torch.bfloat16 - elif args.save_precision == "float": - save_dtype = torch.float32 + weight_dtype, save_dtype = train_util.prepare_dtype(args) # モデルを読み込む - if load_stable_diffusion_format: - print("load StableDiffusion checkpoint") - text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) - else: - print("load Diffusers pretrained models") - pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) - text_encoder = pipe.text_encoder - vae = pipe.vae - unet = pipe.unet - del pipe - - # VAEを読み込む - if args.vae is not None: - vae = model_util.load_vae(args.vae, weight_dtype) - print("additional VAE loaded") + text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) # モデルに xformers とか memory efficient attention を組み込む - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers) # 学習を準備する if cache_latents: @@ -1131,14 +132,12 @@ def train(args): # lr schedulerを用意する lr_scheduler = diffusers.optimization.get_scheduler( - args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" print("enable full fp16 training.") - # unet.to(weight_dtype) - # text_encoder.to(weight_dtype) network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい @@ -1157,10 +156,14 @@ def train(args): unet.requires_grad_(False) unet.to(accelerator.device, dtype=weight_dtype) - unet.eval() text_encoder.requires_grad_(False) text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.eval() + if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + unet.train() + text_encoder.train() + else: + unet.eval() + text_encoder.eval() network.prepare_grad_etc(text_encoder, unet) @@ -1171,12 +174,7 @@ def train(args): # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ - - def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) - - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + train_util.patch_accelerator_for_fp16_training(accelerator) # resumeする if args.resume is not None: @@ -1211,17 +209,16 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - # 指定したステップ数までText Encoderを学習する:epoch最初の状態 network.on_epoch_start(text_encoder, unet) loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): with torch.no_grad(): - # latentに変換 - if batch["latents"] is not None: + if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) else: + # latentに変換 latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 b_size = latents.shape[0] @@ -1229,39 +226,7 @@ def train(args): with torch.set_grad_enabled(train_text_encoder): # Get the text embedding for conditioning input_ids = batch["input_ids"].to(accelerator.device) - input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] - encoder_hidden_states = encoder_hidden_states.to(weight_dtype) # なぜかこれが必要 - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, args.max_token_length, tokenizer.model_max_length): - states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) + encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype) # Sample noise that we'll add to the latents noise = torch.randn_like(latents, device=latents.device) @@ -1279,7 +244,6 @@ def train(args): if args.v_parameterization: # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う target = noise_scheduler.get_velocity(latents, noise, timesteps) else: target = noise @@ -1326,15 +290,26 @@ def train(args): accelerator.wait_for_everyone() if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: - print("saving checkpoint.") - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, EPOCH_FILE_NAME.format(epoch + 1) + '.' + args.save_model_as) + model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name + + def save_func(): + ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"saving checkpoint: {ckpt_file}") unwrap_model(network).save_weights(ckpt_file, save_dtype) - if args.save_state: - print("saving state.") - accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + def remove_old_func(old_epoch_no): + old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as + old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name) + if os.path.exists(old_ckpt_file): + print(f"removing old checkpoint: {old_ckpt_file}") + os.remove(old_ckpt_file) + + saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) + if saving and args.save_state: + train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no) + + # end of epoch is_main_process = accelerator.is_main_process if is_main_process: @@ -1343,103 +318,37 @@ def train(args): accelerator.end_training() if args.save_state: - print("saving last state.") - os.makedirs(args.output_dir, exist_ok=True) - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す if is_main_process: os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join(args.output_dir, LAST_FILE_NAME + '.' + args.save_model_as) + + model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name + ckpt_name = model_name + '.' + args.save_model_as + ckpt_file = os.path.join(args.output_dir, ckpt_name) + print(f"save trained model to {ckpt_file}") network.save_weights(ckpt_file, save_dtype) print("model saved.") if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) parser = argparse.ArgumentParser() - parser.add_argument("--v2", action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') - parser.add_argument("--v_parameterization", action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする') - parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, - help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") - parser.add_argument("--network_weights", type=str, default=None, - help="pretrained weights for network / 学習するネットワークの初期重み") - parser.add_argument("--shuffle_caption", action="store_true", - help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする") - parser.add_argument("--keep_tokens", type=int, default=None, - help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") - parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") - parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--dataset_repeats", type=int, default=1, - help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") - parser.add_argument("--output_dir", type=str, default=None, - help="directory to output trained model / 学習後のモデル出力先ディレクトリ") - parser.add_argument("--save_precision", type=str, default=None, - choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") + + train_util.add_sd_models_arguments(parser) + train_util.add_dataset_arguments(parser, True, True) + train_util.add_training_arguments(parser, True) + parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)") - parser.add_argument("--save_every_n_epochs", type=int, default=None, - help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") - parser.add_argument("--save_state", action="store_true", - help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") - parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") - parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする") - parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする") - parser.add_argument("--face_crop_aug_range", type=str, default=None, - help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)") - parser.add_argument("--random_crop", action="store_true", - help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)") - parser.add_argument("--debug_dataset", action="store_true", - help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") - parser.add_argument("--resolution", type=str, default=None, - help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)") - parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ") - parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], - help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") - parser.add_argument("--use_8bit_adam", action="store_true", - help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") - parser.add_argument("--mem_eff_attn", action="store_true", - help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") - parser.add_argument("--xformers", action="store_true", - help="use xformers for CrossAttention / CrossAttentionにxformersを使う") - parser.add_argument("--vae", type=str, default=None, - help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ") - parser.add_argument("--cache_latents", action="store_true", - help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)") - parser.add_argument("--enable_bucket", action="store_true", - help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする") - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") - parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") - parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") - parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み") - # parser.add_argument("--stop_text_encoder_training", type=int, default=None, - # help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数") - parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") - parser.add_argument("--gradient_checkpointing", action="store_true", - help="enable gradient checkpointing / grandient checkpointingを有効にする") - parser.add_argument("--gradient_accumulation_steps", type=int, default=1, - help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") - parser.add_argument("--mixed_precision", type=str, default="no", - choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") - parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") - parser.add_argument("--clip_skip", type=int, default=None, - help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") - parser.add_argument("--logging_dir", type=str, default=None, - help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") - parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") - parser.add_argument("--lr_scheduler", type=str, default="constant", - help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") - parser.add_argument("--lr_warmup_steps", type=int, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") + + parser.add_argument("--network_weights", type=str, default=None, + help="pretrained weights for network / 学習するネットワークの初期重み") parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') parser.add_argument("--network_dim", type=int, default=None, help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') @@ -1450,4 +359,4 @@ if __name__ == '__main__': help="only training Text Encoder part / Text Encoder関連部分のみ学習する") args = parser.parse_args() - train(args) \ No newline at end of file + train(args) diff --git a/README_train_network-ja.md b/train_network_README-ja.md similarity index 99% rename from README_train_network-ja.md rename to train_network_README-ja.md index 1ad1b7a..77ef4c1 100644 --- a/README_train_network-ja.md +++ b/train_network_README-ja.md @@ -186,4 +186,4 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA ### 将来拡張について -LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。 \ No newline at end of file +LoRAだけでなく他の拡張にも対応可能ですので、それらも追加予定です。 diff --git a/README_train_network.md b/train_network_README.md similarity index 70% rename from README_train_network.md rename to train_network_README.md index d194cd0..b0363a6 100644 --- a/README_train_network.md +++ b/train_network_README.md @@ -1,35 +1,32 @@ -# Train network documentation translated from japanese ## About learning LoRA [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) (arxiv), [LoRA](https://github.com/microsoft/LoRA) (github) to Stable Applied to Diffusion. -[cloneofsimo's repository](https://github.com/cloneofsimo/lora) was a great reference. thank you. +[cloneofsimo's repository](https://github.com/cloneofsimo/lora) was a great reference. Thank you very much. 8GB VRAM seems to work just fine. ## A Note about Trained Models -Cloneofsimo's repository and d8ahazard's [Drebooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_drebooth_extension) are currently incompatible due to ongoing enhancements (see below). +Cloneofsimo's repository and d8ahazard's [Drebooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_drebooth_extension) are currently incompatible. Because we are doing some enhancements (see below). -In order to generate images using WebUI, it is necessary to merge the learned LoRA model with the Stable Diffusion model using the script in this repository. The resulting merged model file will incorporate the learning results from LoRA. Note that merging is not required when generating images with the script in this repository. - -Note that merging is not required when generating with the image generation script in this repository. +When generating images with WebUI, etc., merge the learned LoRA model with the learning source Stable Diffusion model in advance with the script in this repository, or click here [Extention for WebUI] (https://github .com/kohya-ss/sd-webui-additional-networks). ## Learning method Use train_network.py. -You can learn both the DreamBooth method (using identifiers (sks, etc.) and classes, optionally with regularized images) and the fine tuning method using captions. +You can learn both the DreamBooth method (using identifiers (sks, etc.) and classes, optionally regularized images) and the fine tuning method using captions. Both methods can be learned in much the same way as existing scripts. We will discuss the differences later. ### Using the DreamBooth Method -Please refer to note.com [Environment preparation and DreamBooth learning script](https://note.com/kohya_ss/n/nba4eceaa4594) to prepare the data. +Please refer to [DreamBooth guide](./train_db_README-en.md) and prepare the data. Specify train_network.py instead of train_db.py when training. -Almost all options are available (except model saving related to Stable Diffusion), but stop_text_encoder_training is not supported. +Almost all options are available (except Stable Diffusion model save related), but stop_text_encoder_training is not supported. ### When to use captions @@ -75,7 +72,9 @@ In addition, the following options can be specified. * --text_encoder_lr * Specify when using a learning rate different from the normal learning rate (specified with the --learning_rate option) for the LoRA module associated with the Text Encoder. Some people say that it is better to set the Text Encoder to a slightly lower learning rate (such as 5e-5). -If both --network_train_unet_only and --network_train_text_encoder_only are not specified (default), both Text Encoder and U-Net LoRA modules will be enabled. ## About the merge script +When neither --network_train_unet_only nor --network_train_text_encoder_only is specified (default), both Text Encoder and U-Net LoRA modules are enabled. + +## About the merge script merge_lora.py allows you to merge LoRA training results into a Stable Diffusion model, or merge multiple LoRA models. @@ -109,7 +108,7 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt ### Merge multiple LoRA models -After all, it may not be very useful because it cannot be inferred unless it is merged into the SD model. However, when merging multiple LoRA models one by one into the SD model, and when merging multiple LoRA models and then merging them into the SD model, the result will be slightly different in relation to the calculation order. +Applying multiple LoRA models one by one to the SD model and merging multiple LoRA models and then merging them into the SD model yield slightly different results in relation to the calculation order. For example, a command line like: @@ -143,14 +142,48 @@ Add options --network_module, --network_weights, --network_dim (optional) to gen You can change the LoRA application rate by specifying a value between 0 and 1.0 with the --network_mul option. +## Create a LoRA model from the difference between two models + +It was implemented with reference to [this discussion](https://github.com/cloneofsimo/lora/discussions/56). I used the formula as it is (I don't understand it well, but it seems that singular value decomposition is used for approximation). + +LoRA approximates the difference between two models (for example, the original model after fine tuning and the model after fine tuning). + +### How to run scripts + +Please specify as follows. +``` +python networks\extract_lora_from_models.py --model_org base-model.ckpt + --model_tuned fine-tuned-model.ckpt + --save_to lora-weights.safetensors --dim 4 +``` + +Specify the original Stable Diffusion model for the --model_org option. When applying the created LoRA model, this model will be specified and applied. .ckpt or .safetensors can be specified. + +Specify the Stable Diffusion model to extract the difference in the --model_tuned option. For example, specify a model after fine tuning or DreamBooth. .ckpt or .safetensors can be specified. + +Specify the save destination of the LoRA model in --save_to. Specify the number of dimensions of LoRA in --dim. + +A generated LoRA model can be used in the same way as a trained LoRA model. + +If the Text Encoder is the same for both models, LoRA will be U-Net only LoRA. + +### Other Options + +--v2 + - Please specify when using the v2.x Stable Diffusion model. +--device + - If cuda is specified as ``--device cuda``, the calculation will be performed on the GPU. Processing will be faster (because even the CPU is not that slow, it seems to be at most twice or several times faster). +--save_precision + - Specify the LoRA save format from "float", "fp16", "bf16". Default is float. + ## Additional Information ### Differences from cloneofsimo's repository -As of 12/25, this repository has expanded LoRA application points to Text Encoder's MLP, U-Net's FFN, and Transformer's in/out projection, increasing its expressiveness. However, the amount of memory used increased, and it became the last minute of 8GB instead. +As of 12/25, this repository has expanded LoRA application points to Text Encoder's MLP, U-Net's FFN, and Transformer's in/out projection, increasing its expressiveness. However, the amount of memory used increased instead, and it became the last minute of 8GB. Also, the module replacement mechanism is completely different. -### About future expansion +### About Future Expansion It is possible to support not only LoRA but also other expansions, so we plan to add them as well. \ No newline at end of file diff --git a/upgrade.ps1 b/upgrade.ps1 new file mode 100644 index 0000000..cf46b68 --- /dev/null +++ b/upgrade.ps1 @@ -0,0 +1,3 @@ +git pull +.\venv\Scripts\activate +pip install --upgrade -r requirements.txt \ No newline at end of file