diff --git a/bitsandbytes_windows/main.py b/bitsandbytes_windows/main.py index 71967a1..20a93fa 100644 --- a/bitsandbytes_windows/main.py +++ b/bitsandbytes_windows/main.py @@ -4,7 +4,7 @@ extract factors the build is dependent on: [ ] TODO: Q - What if we have multiple GPUs of different makes? - CUDA version - Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) + - CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication) - CuBLAS-LT: full-build 8-bit optimizer - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) @@ -44,7 +44,7 @@ def get_cuda_version(cuda, cudart_path): minor = (version-(major*1000))//10 if major < 11: - print('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') + print('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!') return f'{major}{minor}' @@ -163,4 +163,4 @@ def evaluate_cuda_setup(): binary_name = get_binary_name() - return binary_name + return binary_name \ No newline at end of file diff --git a/fine_tune.py b/fine_tune.py index b6a0605..39fc15c 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -53,944 +53,668 @@ from torch import einsum import library.model_util as model_util # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う -TOKENIZER_PATH = 'openai/clip-vit-large-patch14' -V2_STABLE_DIFFUSION_PATH = 'stabilityai/stable-diffusion-2' # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ +TOKENIZER_PATH = "openai/clip-vit-large-patch14" +V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ # checkpointファイル名 -EPOCH_STATE_NAME = 'epoch-{:06d}-state' -LAST_STATE_NAME = 'last-state' +EPOCH_STATE_NAME = "epoch-{:06d}-state" +LAST_STATE_NAME = "last-state" -LAST_DIFFUSERS_DIR_NAME = 'last' -EPOCH_DIFFUSERS_DIR_NAME = 'epoch-{:06d}' +LAST_DIFFUSERS_DIR_NAME = "last" +EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}" def collate_fn(examples): - return examples[0] + return examples[0] class FineTuningDataset(torch.utils.data.Dataset): - def __init__( - self, - metadata, - train_data_dir, - batch_size, - tokenizer, - max_token_length, - shuffle_caption, - shuffle_keep_tokens, - dataset_repeats, - debug, - ) -> None: - super().__init__() + def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens, dataset_repeats, debug) -> None: + super().__init__() - self.metadata = metadata - self.train_data_dir = train_data_dir - self.batch_size = batch_size - self.tokenizer: CLIPTokenizer = tokenizer - self.max_token_length = max_token_length - self.shuffle_caption = shuffle_caption - self.shuffle_keep_tokens = shuffle_keep_tokens - self.debug = debug + self.metadata = metadata + self.train_data_dir = train_data_dir + self.batch_size = batch_size + self.tokenizer: CLIPTokenizer = tokenizer + self.max_token_length = max_token_length + self.shuffle_caption = shuffle_caption + self.shuffle_keep_tokens = shuffle_keep_tokens + self.debug = debug - self.tokenizer_max_length = ( - self.tokenizer.model_max_length - if max_token_length is None - else max_token_length + 2 - ) + self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 - print('make buckets') + print("make buckets") - # 最初に数を数える - self.bucket_resos = set() - for img_md in metadata.values(): - if 'train_resolution' 
in img_md: - self.bucket_resos.add(tuple(img_md['train_resolution'])) - self.bucket_resos = list(self.bucket_resos) - self.bucket_resos.sort() - print(f'number of buckets: {len(self.bucket_resos)}') + # 最初に数を数える + self.bucket_resos = set() + for img_md in metadata.values(): + if 'train_resolution' in img_md: + self.bucket_resos.add(tuple(img_md['train_resolution'])) + self.bucket_resos = list(self.bucket_resos) + self.bucket_resos.sort() + print(f"number of buckets: {len(self.bucket_resos)}") - reso_to_index = {} - for i, reso in enumerate(self.bucket_resos): - reso_to_index[reso] = i + reso_to_index = {} + for i, reso in enumerate(self.bucket_resos): + reso_to_index[reso] = i - # bucketに割り当てていく - self.buckets = [[] for _ in range(len(self.bucket_resos))] - n = 1 if dataset_repeats is None else dataset_repeats - images_count = 0 - for image_key, img_md in metadata.items(): - if 'train_resolution' not in img_md: - continue - if not os.path.exists(self.image_key_to_npz_file(image_key)): - continue + # bucketに割り当てていく + self.buckets = [[] for _ in range(len(self.bucket_resos))] + n = 1 if dataset_repeats is None else dataset_repeats + images_count = 0 + for image_key, img_md in metadata.items(): + if 'train_resolution' not in img_md: + continue + if not os.path.exists(self.image_key_to_npz_file(image_key)): + continue - reso = tuple(img_md['train_resolution']) - for _ in range(n): - self.buckets[reso_to_index[reso]].append(image_key) - images_count += n + reso = tuple(img_md['train_resolution']) + for _ in range(n): + self.buckets[reso_to_index[reso]].append(image_key) + images_count += n - # 参照用indexを作る - self.buckets_indices = [] - for bucket_index, bucket in enumerate(self.buckets): - batch_count = int(math.ceil(len(bucket) / self.batch_size)) - for batch_index in range(batch_count): - self.buckets_indices.append((bucket_index, batch_index)) + # 参照用indexを作る + self.buckets_indices = [] + for bucket_index, bucket in enumerate(self.buckets): + batch_count = int(math.ceil(len(bucket) / self.batch_size)) + for batch_index in range(batch_count): + self.buckets_indices.append((bucket_index, batch_index)) - self.shuffle_buckets() - self._length = len(self.buckets_indices) - self.images_count = images_count + self.shuffle_buckets() + self._length = len(self.buckets_indices) + self.images_count = images_count - def show_buckets(self): - for i, (reso, bucket) in enumerate( - zip(self.bucket_resos, self.buckets) - ): - print(f'bucket {i}: resolution {reso}, count: {len(bucket)}') + def show_buckets(self): + for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)): + print(f"bucket {i}: resolution {reso}, count: {len(bucket)}") - def shuffle_buckets(self): - random.shuffle(self.buckets_indices) - for bucket in self.buckets: - random.shuffle(bucket) + def shuffle_buckets(self): + random.shuffle(self.buckets_indices) + for bucket in self.buckets: + random.shuffle(bucket) - def image_key_to_npz_file(self, image_key): - npz_file_norm = os.path.splitext(image_key)[0] + '.npz' - if os.path.exists(npz_file_norm): - if random.random() < 0.5: - npz_file_flip = os.path.splitext(image_key)[0] + '_flip.npz' - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm + def image_key_to_npz_file(self, image_key): + npz_file_norm = os.path.splitext(image_key)[0] + '.npz' + if os.path.exists(npz_file_norm): + if random.random() < .5: + npz_file_flip = os.path.splitext(image_key)[0] + '_flip.npz' + if os.path.exists(npz_file_flip): + return npz_file_flip + return npz_file_norm - 
npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') - if random.random() < 0.5: - npz_file_flip = os.path.join( - self.train_data_dir, image_key + '_flip.npz' - ) - if os.path.exists(npz_file_flip): - return npz_file_flip - return npz_file_norm + npz_file_norm = os.path.join(self.train_data_dir, image_key + '.npz') + if random.random() < .5: + npz_file_flip = os.path.join(self.train_data_dir, image_key + '_flip.npz') + if os.path.exists(npz_file_flip): + return npz_file_flip + return npz_file_norm - def load_latent(self, image_key): - return np.load(self.image_key_to_npz_file(image_key))['arr_0'] + def load_latent(self, image_key): + return np.load(self.image_key_to_npz_file(image_key))['arr_0'] - def __len__(self): - return self._length + def __len__(self): + return self._length - def __getitem__(self, index): - if index == 0: - self.shuffle_buckets() + def __getitem__(self, index): + if index == 0: + self.shuffle_buckets() - bucket = self.buckets[self.buckets_indices[index][0]] - image_index = self.buckets_indices[index][1] * self.batch_size + bucket = self.buckets[self.buckets_indices[index][0]] + image_index = self.buckets_indices[index][1] * self.batch_size - input_ids_list = [] - latents_list = [] - captions = [] - for image_key in bucket[image_index : image_index + self.batch_size]: - img_md = self.metadata[image_key] - caption = img_md.get('caption') - tags = img_md.get('tags') + input_ids_list = [] + latents_list = [] + captions = [] + for image_key in bucket[image_index:image_index + self.batch_size]: + img_md = self.metadata[image_key] + caption = img_md.get('caption') + tags = img_md.get('tags') - if caption is None: - caption = tags - elif tags is not None and len(tags) > 0: - caption = caption + ', ' + tags - assert ( - caption is not None and len(caption) > 0 - ), f'caption or tag is required / キャプションまたはタグは必須です:{image_key}' + if caption is None: + caption = tags + elif tags is not None and len(tags) > 0: + caption = caption + ', ' + tags + assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}" - latents = self.load_latent(image_key) + latents = self.load_latent(image_key) - if self.shuffle_caption: - tokens = caption.strip().split(',') - if self.shuffle_keep_tokens is None: - random.shuffle(tokens) - else: - if len(tokens) > self.shuffle_keep_tokens: - keep_tokens = tokens[: self.shuffle_keep_tokens] - tokens = tokens[self.shuffle_keep_tokens :] - random.shuffle(tokens) - tokens = keep_tokens + tokens - caption = ','.join(tokens).strip() + if self.shuffle_caption: + tokens = caption.strip().split(",") + if self.shuffle_keep_tokens is None: + random.shuffle(tokens) + else: + if len(tokens) > self.shuffle_keep_tokens: + keep_tokens = tokens[:self.shuffle_keep_tokens] + tokens = tokens[self.shuffle_keep_tokens:] + random.shuffle(tokens) + tokens = keep_tokens + tokens + caption = ",".join(tokens).strip() - captions.append(caption) + captions.append(caption) - input_ids = self.tokenizer( - caption, - padding='max_length', - truncation=True, - max_length=self.tokenizer_max_length, - return_tensors='pt', - ).input_ids + input_ids = self.tokenizer(caption, padding="max_length", truncation=True, + max_length=self.tokenizer_max_length, return_tensors="pt").input_ids - if self.tokenizer_max_length > self.tokenizer.model_max_length: - input_ids = input_ids.squeeze(0) - iids_list = [] - if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: - # v1 - # 77以上の時は " .... 
" でトータル227とかになっているので、"..."の三連に変換する - # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に - for i in range( - 1, - self.tokenizer_max_length - - self.tokenizer.model_max_length - + 2, - self.tokenizer.model_max_length - 2, - ): # (1, 152, 75) - ids_chunk = ( - input_ids[0].unsqueeze(0), - input_ids[ - i : i + self.tokenizer.model_max_length - 2 - ], - input_ids[-1].unsqueeze(0), - ) - ids_chunk = torch.cat(ids_chunk) - iids_list.append(ids_chunk) - else: - # v2 - # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する - for i in range( - 1, - self.tokenizer_max_length - - self.tokenizer.model_max_length - + 2, - self.tokenizer.model_max_length - 2, - ): - ids_chunk = ( - input_ids[0].unsqueeze(0), # BOS - input_ids[ - i : i + self.tokenizer.model_max_length - 2 - ], - input_ids[-1].unsqueeze(0), - ) # PAD or EOS - ids_chunk = torch.cat(ids_chunk) + if self.tokenizer_max_length > self.tokenizer.model_max_length: + input_ids = input_ids.squeeze(0) + iids_list = [] + if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + # v1 + # 77以上の時は " .... " でトータル227とかになっているので、"..."の三連に変換する + # 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75) + ids_chunk = (input_ids[0].unsqueeze(0), + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) + ids_chunk = torch.cat(ids_chunk) + iids_list.append(ids_chunk) + else: + # v2 + # 77以上の時は " .... ..." でトータル227とかになっているので、"... ..."の三連に変換する + for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): + ids_chunk = (input_ids[0].unsqueeze(0), # BOS + input_ids[i:i + self.tokenizer.model_max_length - 2], + input_ids[-1].unsqueeze(0)) # PAD or EOS + ids_chunk = torch.cat(ids_chunk) - # 末尾が または の場合は、何もしなくてよい - # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) - if ( - ids_chunk[-2] != self.tokenizer.eos_token_id - and ids_chunk[-2] != self.tokenizer.pad_token_id - ): - ids_chunk[-1] = self.tokenizer.eos_token_id - # 先頭が ... の場合は ... に変える - if ids_chunk[1] == self.tokenizer.pad_token_id: - ids_chunk[1] = self.tokenizer.eos_token_id + # 末尾が または の場合は、何もしなくてよい + # 末尾が x の場合は末尾を に変える(x なら結果的に変化なし) + if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id: + ids_chunk[-1] = self.tokenizer.eos_token_id + # 先頭が ... の場合は ... 
に変える + if ids_chunk[1] == self.tokenizer.pad_token_id: + ids_chunk[1] = self.tokenizer.eos_token_id - iids_list.append(ids_chunk) + iids_list.append(ids_chunk) - input_ids = torch.stack(iids_list) # 3,77 + input_ids = torch.stack(iids_list) # 3,77 - input_ids_list.append(input_ids) - latents_list.append(torch.FloatTensor(latents)) + input_ids_list.append(input_ids) + latents_list.append(torch.FloatTensor(latents)) - example = {} - example['input_ids'] = torch.stack(input_ids_list) - example['latents'] = torch.stack(latents_list) - if self.debug: - example['image_keys'] = bucket[ - image_index : image_index + self.batch_size - ] - example['captions'] = captions - return example + example = {} + example['input_ids'] = torch.stack(input_ids_list) + example['latents'] = torch.stack(latents_list) + if self.debug: + example['image_keys'] = bucket[image_index:image_index + self.batch_size] + example['captions'] = captions + return example def save_hypernetwork(output_file, hypernetwork): - state_dict = hypernetwork.get_state_dict() - torch.save(state_dict, output_file) + state_dict = hypernetwork.get_state_dict() + torch.save(state_dict, output_file) def train(args): - fine_tuning = ( - args.hypernetwork_module is None - ) # fine tuning or hypernetwork training + fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training - # その他のオプション設定を確認する - if args.v_parameterization and not args.v2: - print( - 'v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません' - ) - if args.v2 and args.clip_skip is not None: - print( - 'v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません' - ) + # その他のオプション設定を確認する + if args.v_parameterization and not args.v2: + print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") + if args.v2 and args.clip_skip is not None: + print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") - # モデル形式のオプション設定を確認する - load_stable_diffusion_format = os.path.isfile( - args.pretrained_model_name_or_path - ) + # モデル形式のオプション設定を確認する + load_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path) - if load_stable_diffusion_format: - src_stable_diffusion_ckpt = args.pretrained_model_name_or_path - src_diffusers_model_path = None + if load_stable_diffusion_format: + src_stable_diffusion_ckpt = args.pretrained_model_name_or_path + src_diffusers_model_path = None + else: + src_stable_diffusion_ckpt = None + src_diffusers_model_path = args.pretrained_model_name_or_path + + if args.save_model_as is None: + save_stable_diffusion_format = load_stable_diffusion_format + use_safetensors = args.use_safetensors + else: + save_stable_diffusion_format = args.save_model_as.lower() == 'ckpt' or args.save_model_as.lower() == 'safetensors' + use_safetensors = args.use_safetensors or ("safetensors" in args.save_model_as.lower()) + + # 乱数系列を初期化する + if args.seed is not None: + set_seed(args.seed) + + # メタデータを読み込む + if os.path.exists(args.in_json): + print(f"loading existing metadata: {args.in_json}") + with open(args.in_json, "rt", encoding='utf-8') as f: + metadata = json.load(f) + else: + print(f"no metadata / メタデータファイルがありません: {args.in_json}") + return + + # tokenizerを読み込む + print("prepare tokenizer") + if args.v2: + tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer") + else: + tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) + + if args.max_token_length is not None: + print(f"update token length: {args.max_token_length}") + + # 
datasetを用意する + print("prepare dataset") + train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size, + tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens, + args.dataset_repeats, args.debug_dataset) + + print(f"Total dataset length / データセットの長さ: {len(train_dataset)}") + print(f"Total images / 画像数: {train_dataset.images_count}") + + if len(train_dataset) == 0: + print("No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。") + return + + if args.debug_dataset: + train_dataset.show_buckets() + i = 0 + for example in train_dataset: + print(f"image: {example['image_keys']}") + print(f"captions: {example['captions']}") + print(f"latents: {example['latents'].shape}") + print(f"input_ids: {example['input_ids'].shape}") + print(example['input_ids']) + i += 1 + if i >= 8: + break + return + + # acceleratorを準備する + print("prepare accelerator") + if args.logging_dir is None: + log_with = None + logging_dir = None + else: + log_with = "tensorboard" + log_prefix = "" if args.log_prefix is None else args.log_prefix + logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime()) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir) + + # accelerateの互換性問題を解決する + accelerator_0_15 = True + try: + accelerator.unwrap_model("dummy", True) + print("Using accelerator 0.15.0 or above.") + except TypeError: + accelerator_0_15 = False + + def unwrap_model(model): + if accelerator_0_15: + return accelerator.unwrap_model(model, True) + return accelerator.unwrap_model(model) + + # mixed precisionに対応した型を用意しておき適宜castする + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + save_dtype = None + if args.save_precision == "fp16": + save_dtype = torch.float16 + elif args.save_precision == "bf16": + save_dtype = torch.bfloat16 + elif args.save_precision == "float": + save_dtype = torch.float32 + + # モデルを読み込む + if load_stable_diffusion_format: + print("load StableDiffusion checkpoint") + text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path) + else: + print("load Diffusers pretrained models") + pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None) + # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる + text_encoder = pipe.text_encoder + unet = pipe.unet + vae = pipe.vae + del pipe + vae.to("cpu") # 保存時にしか使わないので、メモリを開けるためCPUに移しておく + + # Diffusers版のxformers使用フラグを設定する関数 + def set_diffusers_xformers_flag(model, valid): + # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう + # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) + # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか + # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) + + # Recursively walk through all the children. 
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method + # gets the message + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "set_use_memory_efficient_attention_xformers"): + module.set_use_memory_efficient_attention_xformers(valid) + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + fn_recursive_set_mem_eff(model) + + # モデルに xformers とか memory efficient attention を組み込む + if args.diffusers_xformers: + print("Use xformers by Diffusers") + set_diffusers_xformers_flag(unet, True) + else: + # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある + print("Disable Diffusers' xformers") + set_diffusers_xformers_flag(unet, False) + replace_unet_modules(unet, args.mem_eff_attn, args.xformers) + + if not fine_tuning: + # Hypernetwork + print("import hypernetwork module:", args.hypernetwork_module) + hyp_module = importlib.import_module(args.hypernetwork_module) + + hypernetwork = hyp_module.Hypernetwork() + + if args.hypernetwork_weights is not None: + print("load hypernetwork weights from:", args.hypernetwork_weights) + hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') + success = hypernetwork.load_from_state_dict(hyp_sd) + assert success, "hypernetwork weights loading failed." + + print("apply hypernetwork") + hypernetwork.apply_to_diffusers(None, text_encoder, unet) + + # 学習を準備する:モデルを適切な状態にする + training_models = [] + if fine_tuning: + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + training_models.append(unet) + + if args.train_text_encoder: + print("enable text encoder training") + if args.gradient_checkpointing: + text_encoder.gradient_checkpointing_enable() + training_models.append(text_encoder) else: - src_stable_diffusion_ckpt = None - src_diffusers_model_path = args.pretrained_model_name_or_path + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) # text encoderは学習しない + text_encoder.eval() + else: + unet.to(accelerator.device) # , dtype=weight_dtype) # dtypeを指定すると学習できない + unet.requires_grad_(False) + unet.eval() + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder.requires_grad_(False) + text_encoder.eval() + training_models.append(hypernetwork) - if args.save_model_as is None: - save_stable_diffusion_format = load_stable_diffusion_format - use_safetensors = args.use_safetensors - else: - save_stable_diffusion_format = ( - args.save_model_as.lower() == 'ckpt' - or args.save_model_as.lower() == 'safetensors' - ) - use_safetensors = args.use_safetensors or ( - 'safetensors' in args.save_model_as.lower() - ) + for m in training_models: + m.requires_grad_(True) + params = [] + for m in training_models: + params.extend(m.parameters()) + params_to_optimize = params - # 乱数系列を初期化する - if args.seed is not None: - set_seed(args.seed) + # 学習に必要なクラスを準備する + print("prepare optimizer, data loader etc.") - # メタデータを読み込む - if os.path.exists(args.in_json): - print(f'loading existing metadata: {args.in_json}') - with open(args.in_json, 'rt', encoding='utf-8') as f: - metadata = json.load(f) - else: - print(f'no metadata / メタデータファイルがありません: {args.in_json}') - return - - # tokenizerを読み込む - print('prepare tokenizer') - if args.v2: - tokenizer = CLIPTokenizer.from_pretrained( - V2_STABLE_DIFFUSION_PATH, subfolder='tokenizer' - ) - else: - tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) - - if args.max_token_length is not None: - print(f'update token length: {args.max_token_length}') - - # datasetを用意する - 
print('prepare dataset') - train_dataset = FineTuningDataset( - metadata, - args.train_data_dir, - args.train_batch_size, - tokenizer, - args.max_token_length, - args.shuffle_caption, - args.keep_tokens, - args.dataset_repeats, - args.debug_dataset, - ) - - print(f'Total dataset length / データセットの長さ: {len(train_dataset)}') - print(f'Total images / 画像数: {train_dataset.images_count}') - - if len(train_dataset) == 0: - print( - 'No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。' - ) - return - - if args.debug_dataset: - train_dataset.show_buckets() - i = 0 - for example in train_dataset: - print(f"image: {example['image_keys']}") - print(f"captions: {example['captions']}") - print(f"latents: {example['latents'].shape}") - print(f"input_ids: {example['input_ids'].shape}") - print(example['input_ids']) - i += 1 - if i >= 8: - break - return - - # acceleratorを準備する - print('prepare accelerator') - if args.logging_dir is None: - log_with = None - logging_dir = None - else: - log_with = 'tensorboard' - log_prefix = '' if args.log_prefix is None else args.log_prefix - logging_dir = ( - args.logging_dir - + '/' - + log_prefix - + time.strftime('%Y%m%d%H%M%S', time.localtime()) - ) - accelerator = Accelerator( - gradient_accumulation_steps=args.gradient_accumulation_steps, - mixed_precision=args.mixed_precision, - log_with=log_with, - logging_dir=logging_dir, - ) - - # accelerateの互換性問題を解決する - accelerator_0_15 = True + # 8-bit Adamを使う + if args.use_8bit_adam: try: - accelerator.unwrap_model('dummy', True) - print('Using accelerator 0.15.0 or above.') - except TypeError: - accelerator_0_15 = False + import bitsandbytes as bnb + except ImportError: + raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") + print("use 8-bit Adam optimizer") + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW - def unwrap_model(model): - if accelerator_0_15: - return accelerator.unwrap_model(model, True) - return accelerator.unwrap_model(model) + # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 + optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) - # mixed precisionに対応した型を用意しておき適宜castする - weight_dtype = torch.float32 - if args.mixed_precision == 'fp16': - weight_dtype = torch.float16 - elif args.mixed_precision == 'bf16': - weight_dtype = torch.bfloat16 + # dataloaderを準備する + # DataLoaderのプロセス数:0はメインプロセスになる + n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers) - save_dtype = None - if args.save_precision == 'fp16': - save_dtype = torch.float16 - elif args.save_precision == 'bf16': - save_dtype = torch.bfloat16 - elif args.save_precision == 'float': - save_dtype = torch.float32 + # lr schedulerを用意する + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps) - # モデルを読み込む - if load_stable_diffusion_format: - print('load StableDiffusion checkpoint') - ( - text_encoder, - vae, - unet, - ) = model_util.load_models_from_stable_diffusion_checkpoint( - args.v2, args.pretrained_model_name_or_path - ) - else: - print('load Diffusers pretrained models') - pipe = StableDiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - tokenizer=None, - safety_checker=None, 
- ) - # , torch_dtype=weight_dtype) ここでtorch_dtypeを指定すると学習時にエラーになる - text_encoder = pipe.text_encoder - unet = pipe.unet - vae = pipe.vae - del pipe - vae.to('cpu') # 保存時にしか使わないので、メモリを開けるためCPUに移しておく + # acceleratorがなんかよろしくやってくれるらしい + if args.full_fp16: + assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" + print("enable full fp16 training.") - # Diffusers版のxformers使用フラグを設定する関数 - def set_diffusers_xformers_flag(model, valid): - # model.set_use_memory_efficient_attention_xformers(valid) # 次のリリースでなくなりそう - # pipeが自動で再帰的にset_use_memory_efficient_attention_xformersを探すんだって(;´Д`) - # U-Netだけ使う時にはどうすればいいのか……仕方ないからコピって使うか - # 0.10.2でなんか巻き戻って個別に指定するようになった(;^ω^) - - # Recursively walk through all the children. - # Any children which exposes the set_use_memory_efficient_attention_xformers method - # gets the message - def fn_recursive_set_mem_eff(module: torch.nn.Module): - if hasattr(module, 'set_use_memory_efficient_attention_xformers'): - module.set_use_memory_efficient_attention_xformers(valid) - - for child in module.children(): - fn_recursive_set_mem_eff(child) - - fn_recursive_set_mem_eff(model) - - # モデルに xformers とか memory efficient attention を組み込む - if args.diffusers_xformers: - print('Use xformers by Diffusers') - set_diffusers_xformers_flag(unet, True) - else: - # Windows版のxformersはfloatで学習できないのでxformersを使わない設定も可能にしておく必要がある - print("Disable Diffusers' xformers") - set_diffusers_xformers_flag(unet, False) - replace_unet_modules(unet, args.mem_eff_attn, args.xformers) - - if not fine_tuning: - # Hypernetwork - print('import hypernetwork module:', args.hypernetwork_module) - hyp_module = importlib.import_module(args.hypernetwork_module) - - hypernetwork = hyp_module.Hypernetwork() - - if args.hypernetwork_weights is not None: - print('load hypernetwork weights from:', args.hypernetwork_weights) - hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu') - success = hypernetwork.load_from_state_dict(hyp_sd) - assert success, 'hypernetwork weights loading failed.' 
- - print('apply hypernetwork') - hypernetwork.apply_to_diffusers(None, text_encoder, unet) - - # 学習を準備する:モデルを適切な状態にする - training_models = [] - if fine_tuning: - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - training_models.append(unet) - - if args.train_text_encoder: - print('enable text encoder training') - if args.gradient_checkpointing: - text_encoder.gradient_checkpointing_enable() - training_models.append(text_encoder) - else: - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) # text encoderは学習しない - text_encoder.eval() - else: - unet.to( - accelerator.device - ) # , dtype=weight_dtype) # dtypeを指定すると学習できない - unet.requires_grad_(False) - unet.eval() - text_encoder.to(accelerator.device, dtype=weight_dtype) - text_encoder.requires_grad_(False) - text_encoder.eval() - training_models.append(hypernetwork) - - for m in training_models: - m.requires_grad_(True) - params = [] - for m in training_models: - params.extend(m.parameters()) - params_to_optimize = params - - # 学習に必要なクラスを準備する - print('prepare optimizer, data loader etc.') - - # 8-bit Adamを使う - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - 'No bitsand bytes / bitsandbytesがインストールされていないようです' - ) - print('use 8-bit Adam optimizer') - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - # betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略 - optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate) - - # dataloaderを準備する - # DataLoaderのプロセス数:0はメインプロセスになる - n_workers = min(8, os.cpu_count() - 1) # cpu_count-1 ただし最大8 - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=1, - shuffle=False, - collate_fn=collate_fn, - num_workers=n_workers, - ) - - # lr schedulerを用意する - lr_scheduler = diffusers.optimization.get_scheduler( - args.lr_scheduler, - optimizer, - num_warmup_steps=args.lr_warmup_steps, - num_training_steps=args.max_train_steps - * args.gradient_accumulation_steps, - ) - - # acceleratorがなんかよろしくやってくれるらしい + if fine_tuning: + # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: - assert ( - args.mixed_precision == 'fp16' - ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。" - print('enable full fp16 training.') + unet.to(weight_dtype) + text_encoder.to(weight_dtype) - if fine_tuning: - # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする - if args.full_fp16: - unet.to(weight_dtype) - text_encoder.to(weight_dtype) - - if args.train_text_encoder: - ( - unet, - text_encoder, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) - else: - ( - unet, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler) else: - if args.full_fp16: - unet.to(weight_dtype) - hypernetwork.to(weight_dtype) - - ( - unet, - hypernetwork, - optimizer, - train_dataloader, - lr_scheduler, - ) = accelerator.prepare( - unet, hypernetwork, optimizer, train_dataloader, lr_scheduler - ) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + 
else: if args.full_fp16: - org_unscale_grads = accelerator.scaler._unscale_grads_ + unet.to(weight_dtype) + hypernetwork.to(weight_dtype) - def _unscale_grads_replacer( - optimizer, inv_scale, found_inf, allow_fp16 - ): - return org_unscale_grads(optimizer, inv_scale, found_inf, True) + unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, hypernetwork, optimizer, train_dataloader, lr_scheduler) - accelerator.scaler._unscale_grads_ = _unscale_grads_replacer + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + org_unscale_grads = accelerator.scaler._unscale_grads_ - # TODO accelerateのconfigに指定した型とオプション指定の型とをチェックして異なれば警告を出す + def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16): + return org_unscale_grads(optimizer, inv_scale, found_inf, True) - # resumeする - if args.resume is not None: - print(f'resume training from state: {args.resume}') - accelerator.load_state(args.resume) + accelerator.scaler._unscale_grads_ = _unscale_grads_replacer - # epoch数を計算する - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) - num_train_epochs = math.ceil( - args.max_train_steps / num_update_steps_per_epoch - ) + # TODO accelerateのconfigに指定した型とオプション指定の型とをチェックして異なれば警告を出す - # 学習する - total_batch_size = ( - args.train_batch_size - * accelerator.num_processes - * args.gradient_accumulation_steps - ) - print('running training / 学習開始') - print(f' num examples / サンプル数: {train_dataset.images_count}') - print(f' num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}') - print(f' num epochs / epoch数: {num_train_epochs}') - print(f' batch size per device / バッチサイズ: {args.train_batch_size}') - print( - f' total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}' - ) - print( - f' gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}' - ) - print(f' total optimization steps / 学習ステップ数: {args.max_train_steps}') + # resumeする + if args.resume is not None: + print(f"resume training from state: {args.resume}") + accelerator.load_state(args.resume) - progress_bar = tqdm( - range(args.max_train_steps), - smoothing=0, - disable=not accelerator.is_local_main_process, - desc='steps', - ) - global_step = 0 + # epoch数を計算する + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) - # v4で更新:clip_sample=Falseに - # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる - # 既存の1.4/1.5/2.0/2.1はすべてschdulerのconfigは(クラス名を除いて)同じ - # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀') - noise_scheduler = DDPMScheduler( - beta_start=0.00085, - beta_end=0.012, - beta_schedule='scaled_linear', - num_train_timesteps=1000, - clip_sample=False, - ) + # 学習する + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + print("running training / 学習開始") + print(f" num examples / サンプル数: {train_dataset.images_count}") + print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") + print(f" num epochs / epoch数: {num_train_epochs}") + print(f" batch size per device / バッチサイズ: {args.train_batch_size}") + print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") + print(f" total optimization 
steps / 学習ステップ数: {args.max_train_steps}") - if accelerator.is_main_process: - accelerator.init_trackers( - 'finetuning' if fine_tuning else 'hypernetwork' - ) + progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + global_step = 0 - # 以下 train_dreambooth.py からほぼコピペ - for epoch in range(num_train_epochs): - print(f'epoch {epoch+1}/{num_train_epochs}') - for m in training_models: - m.train() + # v4で更新:clip_sample=Falseに + # Diffusersのtrain_dreambooth.pyがconfigから持ってくるように変更されたので、clip_sample=Falseになるため、それに合わせる + # 既存の1.4/1.5/2.0/2.1はすべてschedulerのconfigは(クラス名を除いて)同じ + # よくソースを見たら学習時はclip_sampleは関係ないや(;'∀') + noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", + num_train_timesteps=1000, clip_sample=False) - loss_total = 0 - for step, batch in enumerate(train_dataloader): - with accelerator.accumulate( - training_models[0] - ): # 複数モデルに対応していない模様だがとりあえずこうしておく - latents = batch['latents'].to(accelerator.device) - latents = latents * 0.18215 - b_size = latents.shape[0] + if accelerator.is_main_process: + accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork") - # with torch.no_grad(): - with torch.set_grad_enabled(args.train_text_encoder): - # Get the text embedding for conditioning - input_ids = batch['input_ids'].to(accelerator.device) - input_ids = input_ids.reshape( - (-1, tokenizer.model_max_length) - ) # batch_size*3, 77 + # 以下 train_dreambooth.py からほぼコピペ + for epoch in range(num_train_epochs): + print(f"epoch {epoch+1}/{num_train_epochs}") + for m in training_models: + m.train() - if args.clip_skip is None: - encoder_hidden_states = text_encoder(input_ids)[0] - else: - enc_out = text_encoder( - input_ids, - output_hidden_states=True, - return_dict=True, - ) - encoder_hidden_states = enc_out['hidden_states'][ - -args.clip_skip - ] - encoder_hidden_states = ( - text_encoder.text_model.final_layer_norm( - encoder_hidden_states - ) - ) + loss_total = 0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + latents = batch["latents"].to(accelerator.device) + latents = latents * 0.18215 + b_size = latents.shape[0] - # bs*3, 77, 768 or 1024 - encoder_hidden_states = encoder_hidden_states.reshape( - (b_size, -1, encoder_hidden_states.shape[-1]) - ) + # with torch.no_grad(): + with torch.set_grad_enabled(args.train_text_encoder): + # Get the text embedding for conditioning + input_ids = batch["input_ids"].to(accelerator.device) + input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77 - if args.max_token_length is not None: - if args.v2: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [ - encoder_hidden_states[:, 0].unsqueeze(1) - ] # - for i in range( - 1, - args.max_token_length, - tokenizer.model_max_length, - ): - chunk = encoder_hidden_states[ - :, i : i + tokenizer.model_max_length - 2 - ] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if ( - input_ids[j, 1] - == tokenizer.eos_token - ): # 空、つまり ...のパターン - chunk[j, 0] = chunk[ - j, 1 - ] # 次の の値をコピーする - states_list.append( - chunk - ) # の後から の前まで - states_list.append( - encoder_hidden_states[:, -1].unsqueeze(1) - ) # のどちらか - encoder_hidden_states = torch.cat( - states_list, dim=1 - ) - else: - # v1: ... の三連を ... 
へ戻す - states_list = [ - encoder_hidden_states[:, 0].unsqueeze(1) - ] # - for i in range( - 1, - args.max_token_length, - tokenizer.model_max_length, - ): - states_list.append( - encoder_hidden_states[ - :, - i : i + tokenizer.model_max_length - 2, - ] - ) # の後から の前まで - states_list.append( - encoder_hidden_states[:, -1].unsqueeze(1) - ) # - encoder_hidden_states = torch.cat( - states_list, dim=1 - ) + if args.clip_skip is None: + encoder_hidden_states = text_encoder(input_ids)[0] + else: + enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True) + encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip] + encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents, device=latents.device) + # bs*3, 77, 768 or 1024 + encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.config.num_train_timesteps, - (b_size,), - device=latents.device, - ) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise( - latents, noise, timesteps - ) - - # Predict the noise residual - noise_pred = unet( - noisy_latents, timesteps, encoder_hidden_states - ).sample - - if args.v_parameterization: - # v-parameterization training - # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う - target = noise_scheduler.get_velocity( - latents, noise, timesteps - ) - else: - target = noise - - loss = torch.nn.functional.mse_loss( - noise_pred.float(), target.float(), reduction='mean' - ) - - accelerator.backward(loss) - if accelerator.sync_gradients: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_( - params_to_clip, 1.0 - ) # args.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - - current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず - if args.logging_dir is not None: - logs = { - 'loss': current_loss, - 'lr': lr_scheduler.get_last_lr()[0], - } - accelerator.log(logs, step=global_step) - - loss_total += current_loss - avr_loss = loss_total / (step + 1) - logs = {'loss': avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= args.max_train_steps: - break - - if args.logging_dir is not None: - logs = {'epoch_loss': loss_total / len(train_dataloader)} - accelerator.log(logs, step=epoch + 1) - - accelerator.wait_for_everyone() - - if args.save_every_n_epochs is not None: - if (epoch + 1) % args.save_every_n_epochs == 0 and ( - epoch + 1 - ) < num_train_epochs: - print('saving checkpoint.') - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join( - args.output_dir, - model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1), - ) - - if fine_tuning: - if save_stable_diffusion_format: - model_util.save_stable_diffusion_checkpoint( - args.v2, - ckpt_file, - unwrap_model(text_encoder), - unwrap_model(unet), - src_stable_diffusion_ckpt, - epoch + 1, - global_step, - save_dtype, - vae, - ) - else: - out_dir = os.path.join( - args.output_dir, - 
EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1), - ) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint( - args.v2, - out_dir, - unwrap_model(text_encoder), - unwrap_model(unet), - src_diffusers_model_path, - vae=vae, - use_safetensors=use_safetensors, - ) - else: - save_hypernetwork(ckpt_file, unwrap_model(hypernetwork)) - - if args.save_state: - print('saving state.') - accelerator.save_state( - os.path.join( - args.output_dir, EPOCH_STATE_NAME.format(epoch + 1) - ) - ) - - is_main_process = accelerator.is_main_process - if is_main_process: - if fine_tuning: - unet = unwrap_model(unet) - text_encoder = unwrap_model(text_encoder) - else: - hypernetwork = unwrap_model(hypernetwork) - - accelerator.end_training() - - if args.save_state: - print('saving last state.') - accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) - - del accelerator # この後メモリを使うのでこれは消す - - if is_main_process: - os.makedirs(args.output_dir, exist_ok=True) - ckpt_file = os.path.join( - args.output_dir, model_util.get_last_ckpt_name(use_safetensors) - ) - - if fine_tuning: - if save_stable_diffusion_format: - print( - f'save trained model as StableDiffusion checkpoint to {ckpt_file}' - ) - model_util.save_stable_diffusion_checkpoint( - args.v2, - ckpt_file, - text_encoder, - unet, - src_stable_diffusion_ckpt, - epoch, - global_step, - save_dtype, - vae, - ) + if args.max_token_length is not None: + if args.v2: + # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # の後から 最後の前まで + if i > 0: + for j in range(len(chunk)): + if input_ids[j, 1] == tokenizer.eos_token: # 空、つまり ...のパターン + chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする + states_list.append(chunk) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか + encoder_hidden_states = torch.cat(states_list, dim=1) else: - # Create the pipeline using using the trained modules and save it. - print(f'save trained model as Diffusers to {args.output_dir}') - out_dir = os.path.join( - args.output_dir, LAST_DIFFUSERS_DIR_NAME - ) - os.makedirs(out_dir, exist_ok=True) - model_util.save_diffusers_checkpoint( - args.v2, - out_dir, - text_encoder, - unet, - src_diffusers_model_path, - vae=vae, - use_safetensors=use_safetensors, - ) - else: - print(f'save trained model to {ckpt_file}') - save_hypernetwork(ckpt_file, hypernetwork) + # v1: ... の三連を ... 
へ戻す + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + for i in range(1, args.max_token_length, tokenizer.model_max_length): + states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # の後から の前まで + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # + encoder_hidden_states = torch.cat(states_list, dim=1) - print('model saved.') + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents, device=latents.device) + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.v_parameterization: + # v-parameterization training + # Diffusers 0.10.0からv_parameterizationの学習に対応したのでそちらを使う + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず + if args.logging_dir is not None: + logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]} + accelerator.log(logs, step=global_step) + + loss_total += current_loss + avr_loss = loss_total / (step+1) + logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + if args.logging_dir is not None: + logs = {"epoch_loss": loss_total / len(train_dataloader)} + accelerator.log(logs, step=epoch+1) + + accelerator.wait_for_everyone() + + if args.save_every_n_epochs is not None: + if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs: + print("saving checkpoint.") + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(use_safetensors, epoch + 1)) + + if fine_tuning: + if save_stable_diffusion_format: + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, unwrap_model(text_encoder), unwrap_model(unet), + src_stable_diffusion_ckpt, epoch + 1, global_step, save_dtype, vae) + else: + out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1)) + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, unwrap_model(text_encoder), unwrap_model(unet), + src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) + else: + save_hypernetwork(ckpt_file, unwrap_model(hypernetwork)) + + if args.save_state: + print("saving state.") + accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1))) + + is_main_process = accelerator.is_main_process + if is_main_process: + if fine_tuning: + unet = unwrap_model(unet) + 
text_encoder = unwrap_model(text_encoder) + else: + hypernetwork = unwrap_model(hypernetwork) + + accelerator.end_training() + + if args.save_state: + print("saving last state.") + accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME)) + + del accelerator # この後メモリを使うのでこれは消す + + if is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(use_safetensors)) + + if fine_tuning: + if save_stable_diffusion_format: + print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}") + model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet, + src_stable_diffusion_ckpt, epoch, global_step, save_dtype, vae) + else: + # Create the pipeline using using the trained modules and save it. + print(f"save trained model as Diffusers to {args.output_dir}") + out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME) + os.makedirs(out_dir, exist_ok=True) + model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, + src_diffusers_model_path, vae=vae, use_safetensors=use_safetensors) + else: + print(f"save trained model to {ckpt_file}") + save_hypernetwork(ckpt_file, hypernetwork) + + print("model saved.") # region モジュール入れ替え部 @@ -1010,12 +734,11 @@ EPSILON = 1e-6 def exists(val): - return val is not None + return val is not None def default(val, d): - return val if exists(val) else d - + return val if exists(val) else d # flash attention forwards and backwards @@ -1023,516 +746,314 @@ def default(val, d): class FlashAttentionFunction(torch.autograd.function.Function): - @staticmethod - @torch.no_grad() - def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): - """Algorithm 2 in the paper""" + @ staticmethod + @ torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """ Algorithm 2 in the paper """ - device = q.device - dtype = q.dtype - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - o = torch.zeros_like(q) - all_row_sums = torch.zeros( - (*q.shape[:-1], 1), dtype=dtype, device=device - ) - all_row_maxes = torch.full( - (*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device - ) + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) - scale = q.shape[-1] ** -0.5 + scale = (q.shape[-1] ** -0.5) - if not exists(mask): - mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) - else: - mask = rearrange(mask, 'b n -> b 1 1 n') - mask = mask.split(q_bucket_size, dim=-1) + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, 'b n -> b 1 1 n') + mask = mask.split(q_bucket_size, dim=-1) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - mask, - all_row_sums.split(q_bucket_size, dim=-2), - all_row_maxes.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) - for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate( - row_splits - ): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, row_mask, row_sums, 
row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = ( - einsum('... i d, ... j d -> ... i j', qc, kc) * scale - ) + attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale - if exists(row_mask): - attn_weights.masked_fill_(~row_mask, max_neg_value) + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) - if causal and q_start_index < ( - k_start_index + k_bucket_size - 1 - ): - causal_mask = torch.ones( - (qc.shape[-2], kc.shape[-2]), - dtype=torch.bool, - device=device, - ).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) - block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) - attn_weights -= block_row_maxes - exp_weights = torch.exp(attn_weights) + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) - if exists(row_mask): - exp_weights.masked_fill_(~row_mask, 0.0) + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.) - block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp( - min=EPSILON - ) + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) - new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) - exp_values = einsum( - '... i j, ... j d -> ... i d', exp_weights, vc - ) + exp_values = einsum('... i j, ... j d -> ... 
i d', exp_weights, vc) - exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) - exp_block_row_max_diff = torch.exp( - block_row_maxes - new_row_maxes - ) + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) - new_row_sums = ( - exp_row_max_diff * row_sums - + exp_block_row_max_diff * block_row_sums - ) + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums - oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_( - (exp_block_row_max_diff / new_row_sums) * exp_values - ) + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) - row_maxes.copy_(new_row_maxes) - row_sums.copy_(new_row_sums) + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) - ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) - ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) - return o + return o - @staticmethod - @torch.no_grad() - def backward(ctx, do): - """Algorithm 4 in the paper""" + @ staticmethod + @ torch.no_grad() + def backward(ctx, do): + """ Algorithm 4 in the paper """ - causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args - q, k, v, o, l, m = ctx.saved_tensors + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors - device = q.device + device = q.device - max_neg_value = -torch.finfo(q.dtype).max - qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) - dq = torch.zeros_like(q) - dk = torch.zeros_like(k) - dv = torch.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) - row_splits = zip( - q.split(q_bucket_size, dim=-2), - o.split(q_bucket_size, dim=-2), - do.split(q_bucket_size, dim=-2), - mask, - l.split(q_bucket_size, dim=-2), - m.split(q_bucket_size, dim=-2), - dq.split(q_bucket_size, dim=-2), - ) + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2) + ) - for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): - q_start_index = ind * q_bucket_size - qk_len_diff + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff - col_splits = zip( - k.split(k_bucket_size, dim=-2), - v.split(k_bucket_size, dim=-2), - dk.split(k_bucket_size, dim=-2), - dv.split(k_bucket_size, dim=-2), - ) + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) - for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): - k_start_index = k_ind * k_bucket_size + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size - attn_weights = ( - einsum('... i d, ... j d -> ... i j', qc, kc) * scale - ) + attn_weights = einsum('... i d, ... j d -> ... 
i j', qc, kc) * scale - if causal and q_start_index < ( - k_start_index + k_bucket_size - 1 - ): - causal_mask = torch.ones( - (qc.shape[-2], kc.shape[-2]), - dtype=torch.bool, - device=device, - ).triu(q_start_index - k_start_index + 1) - attn_weights.masked_fill_(causal_mask, max_neg_value) + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, + device=device).triu(q_start_index - k_start_index + 1) + attn_weights.masked_fill_(causal_mask, max_neg_value) - exp_attn_weights = torch.exp(attn_weights - mc) + exp_attn_weights = torch.exp(attn_weights - mc) - if exists(row_mask): - exp_attn_weights.masked_fill_(~row_mask, 0.0) + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.) - p = exp_attn_weights / lc + p = exp_attn_weights / lc - dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) - dp = einsum('... i d, ... j d -> ... i j', doc, vc) + dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc) + dp = einsum('... i d, ... j d -> ... i j', doc, vc) - D = (doc * oc).sum(dim=-1, keepdims=True) - ds = p * scale * (dp - D) + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) - dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) - dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) + dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc) + dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc) - dqc.add_(dq_chunk) - dkc.add_(dk_chunk) - dvc.add_(dv_chunk) + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) - return dq, dk, dv, None, None, None, None + return dq, dk, dv, None, None, None, None -def replace_unet_modules( - unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, - mem_eff_attn, - xformers, -): - if mem_eff_attn: - replace_unet_cross_attn_to_memory_efficient() - elif xformers: - replace_unet_cross_attn_to_xformers() +def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers): + if mem_eff_attn: + replace_unet_cross_attn_to_memory_efficient() + elif xformers: + replace_unet_cross_attn_to_xformers() def replace_unet_cross_attn_to_memory_efficient(): - print( - 'Replace CrossAttention.forward to use FlashAttention (not xformers)' - ) - flash_func = FlashAttentionFunction + print("Replace CrossAttention.forward to use FlashAttention (not xformers)") + flash_func = FlashAttentionFunction - def forward_flash_attn(self, x, context=None, mask=None): - q_bucket_size = 512 - k_bucket_size = 1024 + def forward_flash_attn(self, x, context=None, mask=None): + q_bucket_size = 512 + k_bucket_size = 1024 - h = self.heads - q = self.to_q(x) + h = self.heads + q = self.to_q(x) - context = context if context is not None else x - context = context.to(x.dtype) + context = context if context is not None else x + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, x - q, k, 
v = map( - lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v) - ) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - out = flash_func.apply( - q, k, v, mask, False, q_bucket_size, k_bucket_size - ) + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, 'b h n d -> b n (h d)') - # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) - out = self.to_out[0](out) - out = self.to_out[1](out) - return out + # diffusers 0.7.0~ わざわざ変えるなよ (;´Д`) + out = self.to_out[0](out) + out = self.to_out[1](out) + return out - diffusers.models.attention.CrossAttention.forward = forward_flash_attn + diffusers.models.attention.CrossAttention.forward = forward_flash_attn def replace_unet_cross_attn_to_xformers(): - print('Replace CrossAttention.forward to use xformers') - try: - import xformers.ops - except ImportError: - raise ImportError('No xformers / xformersがインストールされていないようです') + print("Replace CrossAttention.forward to use xformers") + try: + import xformers.ops + except ImportError: + raise ImportError("No xformers / xformersがインストールされていないようです") - def forward_xformers(self, x, context=None, mask=None): - h = self.heads - q_in = self.to_q(x) + def forward_xformers(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) - context = default(context, x) - context = context.to(x.dtype) + context = default(context, x) + context = context.to(x.dtype) - if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: - context_k, context_v = self.hypernetwork.forward(x, context) - context_k = context_k.to(x.dtype) - context_v = context_v.to(x.dtype) - else: - context_k = context - context_v = context + if hasattr(self, 'hypernetwork') and self.hypernetwork is not None: + context_k, context_v = self.hypernetwork.forward(x, context) + context_k = context_k.to(x.dtype) + context_v = context_v.to(x.dtype) + else: + context_k = context + context_v = context - k_in = self.to_k(context_k) - v_in = self.to_v(context_v) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) - q, k, v = map( - lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), - (q_in, k_in, v_in), - ) - del q_in, k_in, v_in + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - q = q.contiguous() - k = k.contiguous() - v = v.contiguous() - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None - ) # 最適なのを選んでくれる + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる - out = rearrange(out, 'b n h d -> b n (h d)', h=h) - - # diffusers 0.7.0~ - out = self.to_out[0](out) - out = self.to_out[1](out) - return out - - diffusers.models.attention.CrossAttention.forward = forward_xformers + out = rearrange(out, 'b n h d -> b n (h d)', h=h) + # diffusers 0.7.0~ + out = self.to_out[0](out) + out = self.to_out[1](out) + return out + diffusers.models.attention.CrossAttention.forward = forward_xformers # endregion if __name__ == '__main__': - # torch.cuda.set_per_process_memory_fraction(0.48) - parser = argparse.ArgumentParser() - parser.add_argument( - '--v2', - action='store_true', - help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む', - ) - parser.add_argument( - '--v_parameterization', - action='store_true', - help='enable v-parameterization training / v-parameterization学習を有効にする', - ) - parser.add_argument( - 
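
# Editor's note: a standalone sketch (not part of the patch) of the streaming-softmax
# bookkeeping that FlashAttentionFunction.forward above relies on. Attention is
# computed one key/value chunk at a time, and the running row max / row sum are
# rescaled so the final output matches ordinary softmax attention; the backward
# pass ("Algorithm 4") reuses the saved sums and maxes. Shapes below are small
# illustrative values only.
import torch

torch.manual_seed(0)
q = torch.randn(4, 16)                    # 4 queries, head dim 16
k = torch.randn(12, 16)                   # 12 keys, processed in chunks of 4
v = torch.randn(12, 16)
scale = 16 ** -0.5

out = torch.zeros(4, 16)
row_max = torch.full((4, 1), -torch.finfo(q.dtype).max)
row_sum = torch.zeros(4, 1)

for kc, vc in zip(k.split(4), v.split(4)):
    s = (q @ kc.T) * scale                          # scores for this chunk
    block_max = s.amax(dim=-1, keepdim=True)
    p = torch.exp(s - block_max)                    # chunk-local softmax numerator
    block_sum = p.sum(dim=-1, keepdim=True)

    new_max = torch.maximum(row_max, block_max)
    old_scale = torch.exp(row_max - new_max)        # rescales everything accumulated so far
    blk_scale = torch.exp(block_max - new_max)
    new_sum = old_scale * row_sum + blk_scale * block_sum

    out = out * (row_sum * old_scale / new_sum) + (blk_scale / new_sum) * (p @ vc)
    row_max, row_sum = new_max, new_sum

reference = torch.softmax((q @ k.T) * scale, dim=-1) @ v
assert torch.allclose(out, reference, atol=1e-5)
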
'--pretrained_model_name_or_path', - type=str, - default=None, - help='pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル', - ) - parser.add_argument( - '--in_json', - type=str, - default=None, - help='metadata file to input / 読みこむメタデータファイル', - ) - parser.add_argument( - '--shuffle_caption', - action='store_true', - help='shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする', - ) - parser.add_argument( - '--keep_tokens', - type=int, - default=None, - help='keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す', - ) - parser.add_argument( - '--train_data_dir', - type=str, - default=None, - help='directory for train images / 学習画像データのディレクトリ', - ) - parser.add_argument( - '--dataset_repeats', - type=int, - default=None, - help='num times to repeat dataset / 学習にデータセットを繰り返す回数', - ) - parser.add_argument( - '--output_dir', - type=str, - default=None, - help='directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)', - ) - parser.add_argument( - '--save_precision', - type=str, - default=None, - choices=[None, 'float', 'fp16', 'bf16'], - help='precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)', - ) - parser.add_argument( - '--save_model_as', - type=str, - default=None, - choices=[ - None, - 'ckpt', - 'safetensors', - 'diffusers', - 'diffusers_safetensors', - ], - help='format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)', - ) - parser.add_argument( - '--use_safetensors', - action='store_true', - help='use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)', - ) - parser.add_argument( - '--train_text_encoder', - action='store_true', - help='train text encoder / text encoderも学習する', - ) - parser.add_argument( - '--hypernetwork_module', - type=str, - default=None, - help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール', - ) - parser.add_argument( - '--hypernetwork_weights', - type=str, - default=None, - help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)', - ) - parser.add_argument( - '--save_every_n_epochs', - type=int, - default=None, - help='save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する', - ) - parser.add_argument( - '--save_state', - action='store_true', - help='save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する', - ) - parser.add_argument( - '--resume', - type=str, - default=None, - help='saved state to resume training / 学習再開するモデルのstate', - ) - parser.add_argument( - '--max_token_length', - type=int, - default=None, - choices=[None, 150, 225], - help='max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)', - ) - parser.add_argument( - '--train_batch_size', - type=int, - default=1, - help='batch size for training / 学習時のバッチサイズ', - ) - parser.add_argument( - '--use_8bit_adam', - action='store_true', - help='use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)', - ) - parser.add_argument( - '--mem_eff_attn', - action='store_true', - help='use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う', - ) - parser.add_argument( - '--xformers', - action='store_true', - help='use xformers for CrossAttention / CrossAttentionにxformersを使う', - ) - parser.add_argument( - '--diffusers_xformers', - action='store_true', - help="use xformers by diffusers (Hypernetworks doesn't work) / Diffusersでxformersを使用する(Hypernetwork利用不可)", - ) - parser.add_argument( - '--learning_rate', - type=float, - default=2.0e-6, - help='learning rate / 学習率', - ) - parser.add_argument( - '--max_train_steps', - type=int, - default=1600, - help='training steps / 学習ステップ数', - ) - parser.add_argument( - '--seed', - type=int, - default=None, - help='random seed for training / 学習時の乱数のseed', - ) - parser.add_argument( - '--gradient_checkpointing', - action='store_true', - help='enable gradient checkpointing / grandient checkpointingを有効にする', - ) - parser.add_argument( - '--gradient_accumulation_steps', - type=int, - default=1, - help='Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数', - ) - parser.add_argument( - '--mixed_precision', - type=str, - default='no', - choices=['no', 'fp16', 'bf16'], - help='use mixed precision / 混合精度を使う場合、その精度', - ) - parser.add_argument( - '--full_fp16', - action='store_true', - help='fp16 training including gradients / 勾配も含めてfp16で学習する', - ) - parser.add_argument( - '--clip_skip', - type=int, - default=None, - help='use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)', - ) - parser.add_argument( - '--debug_dataset', - action='store_true', - help='show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)', - ) - parser.add_argument( - '--logging_dir', - type=str, - default=None, - help='enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する', - ) - parser.add_argument( - '--log_prefix', - type=str, - default=None, - help='add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列', - ) - parser.add_argument( - '--lr_scheduler', - type=str, - default='constant', - help='scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup', - ) - parser.add_argument( - '--lr_warmup_steps', - type=int, - default=0, - help='Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)', - ) + # torch.cuda.set_per_process_memory_fraction(0.48) + parser = argparse.ArgumentParser() + parser.add_argument("--v2", action='store_true', + help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む') + parser.add_argument("--v_parameterization", action='store_true', + help='enable v-parameterization 
training / v-parameterization学習を有効にする') + parser.add_argument("--pretrained_model_name_or_path", type=str, default=None, + help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル") + parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル") + parser.add_argument("--shuffle_caption", action="store_true", + help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする") + parser.add_argument("--keep_tokens", type=int, default=None, + help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") + parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") + parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数") + parser.add_argument("--output_dir", type=str, default=None, + help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)") + parser.add_argument("--save_precision", type=str, default=None, + choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)") + parser.add_argument("--save_model_as", type=str, default=None, choices=[None, "ckpt", "safetensors", "diffusers", "diffusers_safetensors"], + help="format to save the model (default is same to original) / モデル保存時の形式(未指定時は元モデルと同じ)") + parser.add_argument("--use_safetensors", action='store_true', + help="use safetensors format to save (if save_model_as is not specified) / checkpoint、モデルをsafetensors形式で保存する(save_model_as未指定時)") + parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する") + parser.add_argument("--hypernetwork_module", type=str, default=None, + help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール') + parser.add_argument("--hypernetwork_weights", type=str, default=None, + help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)') + parser.add_argument("--save_every_n_epochs", type=int, default=None, + help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") + parser.add_argument("--save_state", action="store_true", + help="save training state additionally (including optimizer states etc.) 
/ optimizerなど学習状態も含めたstateを追加で保存する") + parser.add_argument("--resume", type=str, default=None, + help="saved state to resume training / 学習再開するモデルのstate") + parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225], + help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)") + parser.add_argument("--train_batch_size", type=int, default=1, + help="batch size for training / 学習時のバッチサイズ") + parser.add_argument("--use_8bit_adam", action="store_true", + help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)") + parser.add_argument("--mem_eff_attn", action="store_true", + help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う") + parser.add_argument("--xformers", action="store_true", + help="use xformers for CrossAttention / CrossAttentionにxformersを使う") + parser.add_argument("--diffusers_xformers", action='store_true', + help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') + parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") + parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") + parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") + parser.add_argument("--gradient_checkpointing", action="store_true", + help="enable gradient checkpointing / grandient checkpointingを有効にする") + parser.add_argument("--gradient_accumulation_steps", type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数") + parser.add_argument("--mixed_precision", type=str, default="no", + choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度") + parser.add_argument("--full_fp16", action="store_true", help="fp16 training including gradients / 勾配も含めてfp16で学習する") + parser.add_argument("--clip_skip", type=int, default=None, + help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)") + parser.add_argument("--debug_dataset", action="store_true", + help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)") + parser.add_argument("--logging_dir", type=str, default=None, + help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する") + parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列") + parser.add_argument("--lr_scheduler", type=str, default="constant", + help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup") + parser.add_argument("--lr_warmup_steps", type=int, default=0, + help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)") - args = parser.parse_args() - train(args) + args = parser.parse_args() + train(args) \ No newline at end of file diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a89aed0..edc007e 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2486,9 +2486,9 @@ if __name__ == '__main__': parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--diffusers_xformers", 
action='store_true', - help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') + help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)') parser.add_argument("--opt_channels_last", action='store_true', - help='set channels last option to model / モデルにchannles lastを指定し最適化する') + help='set channels last option to model / モデルにchannels lastを指定し最適化する') parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み') parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率') @@ -2514,4 +2514,4 @@ if __name__ == '__main__': help="save 1st stage images for highres fix / highres fixの最初のステージの画像を保存する") args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/library/model_util.py b/library/model_util.py index 29d4420..ad2b427 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -5,12 +5,7 @@ import math import os import torch from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig -from diffusers import ( - AutoencoderKL, - DDIMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from safetensors.torch import load_file, save_file # DiffUsers版StableDiffusionのモデルパラメータ @@ -41,8 +36,8 @@ V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] V2_UNET_PARAMS_CONTEXT_DIM = 1024 # Diffusersの設定を読み込むための参照モデル -DIFFUSERS_REF_MODEL_ID_V1 = 'runwayml/stable-diffusion-v1-5' -DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1' +DIFFUSERS_REF_MODEL_ID_V1 = "runwayml/stable-diffusion-v1-5" +DIFFUSERS_REF_MODEL_ID_V2 = "stabilityai/stable-diffusion-2-1" # region StableDiffusion->Diffusersの変換コード @@ -50,845 +45,596 @@ DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1' def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return '.'.join(path.split('.')[n_shave_prefix_segments:]) - else: - return '.'.join(path.split('.')[:n_shave_prefix_segments]) + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. 
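
# Editor's note: a tiny self-contained illustration of what the shave_segments
# docstring above means in practice. shave_segments_sketch is a local stand-in
# with the same body as the function that follows; the dotted paths are invented
# examples.
def shave_segments_sketch(path, n_shave_prefix_segments=1):
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    return ".".join(path.split(".")[:n_shave_prefix_segments])

# positive: drop the first N dotted segments
assert shave_segments_sketch("output_blocks.3.0.in_layers.0.weight", 2) == "0.in_layers.0.weight"
# negative: drop the last |N| dotted segments
assert shave_segments_sketch("output_blocks.3.0.in_layers.0.weight", -1) == "output_blocks.3.0.in_layers.0"
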
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace('in_layers.0', 'norm1') - new_item = new_item.replace('in_layers.2', 'conv1') + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") - new_item = new_item.replace('out_layers.0', 'norm2') - new_item = new_item.replace('out_layers.3', 'conv2') + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") - new_item = new_item.replace('emb_layers.1', 'time_emb_proj') - new_item = new_item.replace('skip_connection', 'conv_shortcut') + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace('nin_shortcut', 'conv_shortcut') - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def 
renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item - new_item = new_item.replace('norm.weight', 'group_norm.weight') - new_item = new_item.replace('norm.bias', 'group_norm.bias') + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace('q.weight', 'query.weight') - new_item = new_item.replace('q.bias', 'query.bias') + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") - new_item = new_item.replace('k.weight', 'key.weight') - new_item = new_item.replace('k.bias', 'key.bias') + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") - new_item = new_item.replace('v.weight', 'value.weight') - new_item = new_item.replace('v.bias', 'value.bias') + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") - new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - new_item = shave_segments( - new_item, n_shave_prefix_segments=n_shave_prefix_segments - ) + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - mapping.append({'old': old_item, 'new': new_item}) + mapping.append({"old": old_item, "new": new_item}) - return mapping + return mapping def assign_to_checkpoint( - paths, - checkpoint, - old_checkpoint, - attention_paths_to_split=None, - additional_replacements=None, - config=None, + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None ): - """ - This does the final conversion step: take locally converted weights and apply a global renaming - to them. It splits attention layers, and takes into account additional replacements - that may arise. + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. - Assigns the weights to the new checkpoint. - """ - assert isinstance( - paths, list - ), "Paths should be a list of dicts containing 'old' and 'new' keys." + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 + # Splits the attention layers into three variables. 
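
# Editor's note: a standalone sketch of the q/k/v split that assign_to_checkpoint
# performs below when attention_paths_to_split is given. An LDM attention block
# stores query, key and value as one fused tensor, which the Diffusers layout
# expects as three separate tensors. The shapes and num_head_channels value here
# are arbitrary illustrative choices, not read from a real checkpoint.
import torch

num_head_channels = 64
fused = torch.randn(3 * 512, 512, 1)           # fused qkv weight, (3 * C, C, 1)
channels = fused.shape[0] // 3
num_heads = fused.shape[0] // num_head_channels // 3

# regroup per head, then split the per-head axis into the q, k and v chunks
per_head = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
query, key, value = per_head.split(channels // num_heads, dim=1)

target_shape = (-1, channels) if fused.ndim == 3 else (-1,)
assert query.reshape(target_shape).shape == (512, 512)
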
+ if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 - target_shape = ( - (-1, channels) if len(old_tensor.shape) == 3 else (-1) - ) + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - num_heads = old_tensor.shape[0] // config['num_head_channels'] // 3 + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - old_tensor = old_tensor.reshape( - (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] - ) - query, key, value = old_tensor.split(channels // num_heads, dim=1) + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map['query']] = query.reshape(target_shape) - checkpoint[path_map['key']] = key.reshape(target_shape) - checkpoint[path_map['value']] = value.reshape(target_shape) + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) - for path in paths: - new_path = path['new'] + for path in paths: + new_path = path["new"] - # These have already been assigned - if ( - attention_paths_to_split is not None - and new_path in attention_paths_to_split - ): - continue + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue - # Global renaming happens here - new_path = new_path.replace('middle_block.0', 'mid_block.resnets.0') - new_path = new_path.replace('middle_block.1', 'mid_block.attentions.0') - new_path = new_path.replace('middle_block.2', 'mid_block.resnets.1') + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace( - replacement['old'], replacement['new'] - ) + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) - # proj_attn.weight has to be converted from conv 1D to linear - if 'proj_attn.weight' in new_path: - checkpoint[new_path] = old_checkpoint[path['old']][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path['old']] + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ['query.weight', 'key.weight', 'value.weight'] - for key in keys: - if '.'.join(key.split('.')[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif 'proj_attn.weight' in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] def 
linear_transformer_to_conv(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ['proj_in.weight', 'proj_out.weight'] - for key in keys: - if '.'.join(key.split('.')[-2:]) in tf_keys: - if checkpoint[key].ndim == 2: - checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) def convert_ldm_unet_checkpoint(v2, checkpoint, config): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ - # extract state_dict for UNet - unet_state_dict = {} - unet_key = 'model.diffusion_model.' - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, '')] = checkpoint.pop(key) + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - new_checkpoint = {} + new_checkpoint = {} - new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict[ - 'time_embed.0.weight' - ] - new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict[ - 'time_embed.0.bias' - ] - new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict[ - 'time_embed.2.weight' - ] - new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict[ - 'time_embed.2.bias' + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - new_checkpoint['conv_in.weight'] = unet_state_dict[ - 'input_blocks.0.0.weight' - ] - new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias'] + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) - new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight'] - new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias'] - new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight'] - new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias'] - - # Retrieves the keys for the input blocks only - num_input_blocks = len( - { - '.'.join(layer.split('.')[:2]) - for layer in unet_state_dict - if 'input_blocks' in layer - } + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) - input_blocks = { - layer_id: [ - key - for key in unet_state_dict - if f'input_blocks.{layer_id}.' in key - ] - for layer_id in range(num_input_blocks) - } - # Retrieves the keys for the middle blocks only - num_middle_blocks = len( - { - '.'.join(layer.split('.')[:2]) - for layer in unet_state_dict - if 'middle_block' in layer - } - ) - middle_blocks = { - layer_id: [ - key - for key in unet_state_dict - if f'middle_block.{layer_id}.' in key - ] - for layer_id in range(num_middle_blocks) - } + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) - # Retrieves the keys for the output blocks only - num_output_blocks = len( - { - '.'.join(layer.split('.')[:2]) - for layer in unet_state_dict - if 'output_blocks' in layer - } - ) - output_blocks = { - layer_id: [ - key - for key in unet_state_dict - if f'output_blocks.{layer_id}.' 
in key - ] - for layer_id in range(num_output_blocks) - } + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config['layers_per_block'] + 1) - layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1) + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - resnets = [ - key - for key in input_blocks[i] - if f'input_blocks.{i}.0' in key - and f'input_blocks.{i}.0.op' not in key + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" ] - attentions = [ - key for key in input_blocks[i] if f'input_blocks.{i}.1' in key + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" ] - if f'input_blocks.{i}.0.op.weight' in unet_state_dict: - new_checkpoint[ - f'down_blocks.{block_id}.downsamplers.0.conv.weight' - ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight') - new_checkpoint[ - f'down_blocks.{block_id}.downsamplers.0.conv.bias' - ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias') + # Clear attentions as they have been attributed above. 
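
# Editor's note: a standalone sketch of the order-insensitive upsampler check
# above. The original conversion compared each per-layer name list against
# ["conv.weight", "conv.bias"], which fails if the two names arrive in the
# opposite order, so every list is sorted first and compared against the sorted
# form (that is what the Japanese comment above is pointing out). The dict
# contents here are invented for illustration.
output_block_list = {"0": ["in_layers.0.weight"], "1": ["conv.weight", "conv.bias"]}

for names in output_block_list.values():
    names.sort()

assert ["conv.bias", "conv.weight"] in output_block_list.values()
assert list(output_block_list.values()).index(["conv.bias", "conv.weight"]) == 1
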
+ if len(attentions) == 2: + attentions = [] - paths = renew_resnet_paths(resnets) + if len(attentions): + paths = renew_attention_paths(attentions) meta_path = { - 'old': f'input_blocks.{i}.0', - 'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}', + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", } assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - 'old': f'input_blocks.{i}.1', - 'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) + new_checkpoint[new_path] = unet_state_dict[old_path] - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] + # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する + if v2: + linear_transformer_to_conv(new_checkpoint) - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint( - resnet_0_paths, new_checkpoint, unet_state_dict, config=config - ) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint( - resnet_1_paths, new_checkpoint, unet_state_dict, config=config - ) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'} - assign_to_checkpoint( - attentions_paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - for i in range(num_output_blocks): - block_id = i // (config['layers_per_block'] + 1) - layer_in_block_id = i % (config['layers_per_block'] + 1) - output_block_layers = [ - shave_segments(name, 2) for name in output_blocks[i] - ] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split('.')[0], shave_segments( - layer, 1 - ) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [ - key - for key in output_blocks[i] - if f'output_blocks.{i}.0' in key - ] - attentions = [ - key - for key in output_blocks[i] - if f'output_blocks.{i}.1' in key - ] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = { - 'old': f'output_blocks.{i}.0', - 'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - # オリジナル: - # if ["conv.weight", "conv.bias"] in output_block_list.values(): - # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) - - # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが - for l in output_block_list.values(): - l.sort() - - if ['conv.bias', 'conv.weight'] in output_block_list.values(): - index = list(output_block_list.values()).index( - ['conv.bias', 'conv.weight'] - ) - new_checkpoint[ - f'up_blocks.{block_id}.upsamplers.0.conv.bias' - ] = 
unet_state_dict[f'output_blocks.{i}.{index}.conv.bias'] - new_checkpoint[ - f'up_blocks.{block_id}.upsamplers.0.conv.weight' - ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight'] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - 'old': f'output_blocks.{i}.1', - 'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - unet_state_dict, - additional_replacements=[meta_path], - config=config, - ) - else: - resnet_0_paths = renew_resnet_paths( - output_block_layers, n_shave_prefix_segments=1 - ) - for path in resnet_0_paths: - old_path = '.'.join(['output_blocks', str(i), path['old']]) - new_path = '.'.join( - [ - 'up_blocks', - str(block_id), - 'resnets', - str(layer_in_block_id), - path['new'], - ] - ) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する - if v2: - linear_transformer_to_conv(new_checkpoint) - - return new_checkpoint + return new_checkpoint def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = 'first_stage_model.' - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, '')] = checkpoint.get(key) - # if len(vae_state_dict) == 0: - # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict - # vae_state_dict = checkpoint + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint - new_checkpoint = {} + new_checkpoint = {} - new_checkpoint['encoder.conv_in.weight'] = vae_state_dict[ - 'encoder.conv_in.weight' - ] - new_checkpoint['encoder.conv_in.bias'] = vae_state_dict[ - 'encoder.conv_in.bias' - ] - new_checkpoint['encoder.conv_out.weight'] = vae_state_dict[ - 'encoder.conv_out.weight' - ] - new_checkpoint['encoder.conv_out.bias'] = vae_state_dict[ - 'encoder.conv_out.bias' - ] - new_checkpoint['encoder.conv_norm_out.weight'] = vae_state_dict[ - 'encoder.norm_out.weight' - ] - new_checkpoint['encoder.conv_norm_out.bias'] = vae_state_dict[ - 'encoder.norm_out.bias' + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = 
vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key ] - new_checkpoint['decoder.conv_in.weight'] = vae_state_dict[ - 'decoder.conv_in.weight' - ] - new_checkpoint['decoder.conv_in.bias'] = vae_state_dict[ - 'decoder.conv_in.bias' - ] - new_checkpoint['decoder.conv_out.weight'] = vae_state_dict[ - 'decoder.conv_out.weight' - ] - new_checkpoint['decoder.conv_out.bias'] = vae_state_dict[ - 'decoder.conv_out.bias' - ] - new_checkpoint['decoder.conv_norm_out.weight'] = vae_state_dict[ - 'decoder.norm_out.weight' - ] - new_checkpoint['decoder.conv_norm_out.bias'] = vae_state_dict[ - 'decoder.norm_out.bias' - ] + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + 
f"decoder.up.{block_id}.upsample.conv.bias" + ] - new_checkpoint['quant_conv.weight'] = vae_state_dict['quant_conv.weight'] - new_checkpoint['quant_conv.bias'] = vae_state_dict['quant_conv.bias'] - new_checkpoint['post_quant_conv.weight'] = vae_state_dict[ - 'post_quant_conv.weight' - ] - new_checkpoint['post_quant_conv.bias'] = vae_state_dict[ - 'post_quant_conv.bias' - ] + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len( - { - '.'.join(layer.split('.')[:3]) - for layer in vae_state_dict - if 'encoder.down' in layer - } - ) - down_blocks = { - layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key] - for layer_id in range(num_down_blocks) - } + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len( - { - '.'.join(layer.split('.')[:3]) - for layer in vae_state_dict - if 'decoder.up' in layer - } - ) - up_blocks = { - layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key] - for layer_id in range(num_up_blocks) - } + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - for i in range(num_down_blocks): - resnets = [ - key - for key in down_blocks[i] - if f'down.{i}' in key and f'down.{i}.downsample' not in key - ] - - if f'encoder.down.{i}.downsample.conv.weight' in vae_state_dict: - new_checkpoint[ - f'encoder.down_blocks.{i}.downsamplers.0.conv.weight' - ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.weight') - new_checkpoint[ - f'encoder.down_blocks.{i}.downsamplers.0.conv.bias' - ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.bias') - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'down.{i}.block', - 'new': f'down_blocks.{i}.resnets', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_resnets = [key for key in vae_state_dict if 'encoder.mid.block' in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [ - key for key in mid_resnets if f'encoder.mid.block_{i}' in key - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'mid.block_{i}', - 'new': f'mid_block.resnets.{i - 1}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_attentions = [ - key for key in vae_state_dict if 'encoder.mid.attn' in key - ] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key - for key in up_blocks[block_id] - if f'up.{block_id}' in key and f'up.{block_id}.upsample' not in key - ] - - if 
f'decoder.up.{block_id}.upsample.conv.weight' in vae_state_dict: - new_checkpoint[ - f'decoder.up_blocks.{i}.upsamplers.0.conv.weight' - ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.weight'] - new_checkpoint[ - f'decoder.up_blocks.{i}.upsamplers.0.conv.bias' - ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.bias'] - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'up.{block_id}.block', - 'new': f'up_blocks.{i}.resnets', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_resnets = [key for key in vae_state_dict if 'decoder.mid.block' in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [ - key for key in mid_resnets if f'decoder.mid.block_{i}' in key - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = { - 'old': f'mid.block_{i}', - 'new': f'mid_block.resnets.{i - 1}', - } - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - - mid_attentions = [ - key for key in vae_state_dict if 'decoder.mid.attn' in key - ] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'} - assign_to_checkpoint( - paths, - new_checkpoint, - vae_state_dict, - additional_replacements=[meta_path], - config=config, - ) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint def create_unet_diffusers_config(v2): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # unet_params = original_config.model.params.unet_config.params + """ + Creates a config for the diffusers based on the config of the LDM model. 
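
# Editor's note: a standalone check of why conv_attn_to_linear and
# linear_transformer_to_conv (used in the conversions above) are lossless:
# a 1x1 Conv2d and a Linear layer apply the same matrix, so converting between
# them only adds or removes two trailing singleton dimensions on the weight.
# Shapes are arbitrary examples.
import torch
import torch.nn.functional as F

x = torch.randn(1, 320, 8, 8)                    # (batch, channels, height, width)
w_linear = torch.randn(320, 320)                 # Linear-style weight

w_conv = w_linear.unsqueeze(2).unsqueeze(2)      # (320, 320, 1, 1), as linear_transformer_to_conv does
y_conv = F.conv2d(x, w_conv)
y_linear = F.linear(x.permute(0, 2, 3, 1), w_linear).permute(0, 3, 1, 2)
assert torch.allclose(y_conv, y_linear, atol=1e-5)

# conv_attn_to_linear goes the other way by indexing [:, :, 0, 0]
assert torch.equal(w_conv[:, :, 0, 0], w_linear)
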
+ """ + # unet_params = original_config.model.params.unet_config.params - block_out_channels = [ - UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT - ] + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = ( - 'CrossAttnDownBlock2D' - if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS - else 'DownBlock2D' - ) - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = ( - 'CrossAttnUpBlock2D' - if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS - else 'UpBlock2D' - ) - up_block_types.append(block_type) - resolution //= 2 + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 - config = dict( - sample_size=UNET_PARAMS_IMAGE_SIZE, - in_channels=UNET_PARAMS_IN_CHANNELS, - out_channels=UNET_PARAMS_OUT_CHANNELS, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, - cross_attention_dim=UNET_PARAMS_CONTEXT_DIM - if not v2 - else V2_UNET_PARAMS_CONTEXT_DIM, - attention_head_dim=UNET_PARAMS_NUM_HEADS - if not v2 - else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, - ) + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + ) - return config + return config def create_vae_diffusers_config(): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - # vae_params = original_config.model.params.first_stage_config.params.ddconfig - # _ = original_config.model.params.first_stage_config.params.embed_dim - block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] - down_block_types = ['DownEncoderBlock2D'] * len(block_out_channels) - up_block_types = ['UpDecoderBlock2D'] * len(block_out_channels) + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - config = dict( - sample_size=VAE_PARAMS_RESOLUTION, - in_channels=VAE_PARAMS_IN_CHANNELS, - out_channels=VAE_PARAMS_OUT_CH, - down_block_types=tuple(down_block_types), - up_block_types=tuple(up_block_types), - block_out_channels=tuple(block_out_channels), - latent_channels=VAE_PARAMS_Z_CHANNELS, - layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, - ) - return config + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config def convert_ldm_clip_checkpoint_v1(checkpoint): - keys = list(checkpoint.keys()) - text_model_dict = {} - for key in keys: - if key.startswith('cond_stage_model.transformer'): - text_model_dict[ - key[len('cond_stage_model.transformer.') :] - ] = checkpoint[key] - return text_model_dict + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key] + return text_model_dict def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): - # 嫌になるくらい違うぞ! - def convert_key(key): - if not key.startswith('cond_stage_model'): - return None + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None - # common conversion - key = key.replace( - 'cond_stage_model.model.transformer.', 'text_model.encoder.' - ) - key = key.replace('cond_stage_model.model.', 'text_model.') + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") - if 'resblocks' in key: - # resblocks conversion - key = key.replace('.resblocks.', '.layers.') - if '.ln_' in key: - key = key.replace('.ln_', '.layer_norm') - elif '.mlp.' in key: - key = key.replace('.c_fc.', '.fc1.') - key = key.replace('.c_proj.', '.fc2.') - elif '.attn.out_proj' in key: - key = key.replace('.attn.out_proj.', '.self_attn.out_proj.') - elif '.attn.in_proj' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f'unexpected key in SD: {key}') - elif '.positional_embedding' in key: - key = key.replace( - '.positional_embedding', - '.embeddings.position_embedding.weight', - ) - elif '.text_projection' in key: - key = None # 使われない??? - elif '.logit_scale' in key: - key = None # 使われない??? - elif '.token_embedding' in key: - key = key.replace( - '.token_embedding.weight', '.embeddings.token_embedding.weight' - ) - elif '.ln_final' in key: - key = key.replace('.ln_final', '.final_layer_norm') - return key + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." 
in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif '.attn.out_proj' in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif '.attn.in_proj' in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif '.positional_embedding' in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif '.text_projection' in key: + key = None # 使われない??? + elif '.logit_scale' in key: + key = None # 使われない??? + elif '.token_embedding' in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif '.ln_final' in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key - keys = list(checkpoint.keys()) - new_sd = {} - for key in keys: - # remove resblocks 23 - if '.resblocks.23.' in key: - continue - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if '.resblocks.23.' in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] - # attnの変換 - for key in keys: - if '.resblocks.23.' in key: - continue - if '.resblocks' in key and '.attn.in_proj_' in key: - # 三つに分割 - values = torch.chunk(checkpoint[key], 3) + # attnの変換 + for key in keys: + if '.resblocks.23.' in key: + continue + if '.resblocks' in key and '.attn.in_proj_' in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) - key_suffix = '.weight' if 'weight' in key else '.bias' - key_pfx = key.replace( - 'cond_stage_model.model.transformer.resblocks.', - 'text_model.encoder.layers.', - ) - key_pfx = key_pfx.replace('_weight', '') - key_pfx = key_pfx.replace('_bias', '') - key_pfx = key_pfx.replace('.attn.in_proj', '.self_attn.') - new_sd[key_pfx + 'q_proj' + key_suffix] = values[0] - new_sd[key_pfx + 'k_proj' + key_suffix] = values[1] - new_sd[key_pfx + 'v_proj' + key_suffix] = values[2] - - # position_idsの追加 - new_sd['text_model.embeddings.position_ids'] = torch.Tensor( - [list(range(max_length))] - ).to(torch.int64) - return new_sd + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd # endregion @@ -896,649 +642,549 @@ def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): # region Diffusers->StableDiffusion の変換コード # convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) - def conv_transformer_to_linear(checkpoint): - keys = list(checkpoint.keys()) - tf_keys = ['proj_in.weight', 'proj_out.weight'] - for key in keys: - if '.'.join(key.split('.')[-2:]) in tf_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = 
checkpoint[key][:, :, 0, 0] + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] def convert_unet_state_dict_to_sd(v2, unet_state_dict): - unet_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ('time_embed.0.weight', 'time_embedding.linear_1.weight'), - ('time_embed.0.bias', 'time_embedding.linear_1.bias'), - ('time_embed.2.weight', 'time_embedding.linear_2.weight'), - ('time_embed.2.bias', 'time_embedding.linear_2.bias'), - ('input_blocks.0.0.weight', 'conv_in.weight'), - ('input_blocks.0.0.bias', 'conv_in.bias'), - ('out.0.weight', 'conv_norm_out.weight'), - ('out.0.bias', 'conv_norm_out.bias'), - ('out.2.weight', 'conv_out.weight'), - ('out.2.bias', 'conv_out.bias'), - ] + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] - unet_conversion_map_resnet = [ - # (stable-diffusion, HF Diffusers) - ('in_layers.0', 'norm1'), - ('in_layers.2', 'conv1'), - ('out_layers.0', 'norm2'), - ('out_layers.3', 'conv2'), - ('emb_layers.1', 'time_emb_proj'), - ('skip_connection', 'conv_shortcut'), - ] + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] - unet_conversion_map_layer = [] - for i in range(4): - # loop over downblocks/upblocks - - for j in range(2): - # loop over resnets/attentions for downblocks - hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.' - sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.' - unet_conversion_map_layer.append( - (sd_down_res_prefix, hf_down_res_prefix) - ) - - if i < 3: - # no attention layers in down_blocks.3 - hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.' - sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.' - unet_conversion_map_layer.append( - (sd_down_atn_prefix, hf_down_atn_prefix) - ) - - for j in range(3): - # loop over resnets/attentions for upblocks - hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.' - sd_up_res_prefix = f'output_blocks.{3*i + j}.0.' - unet_conversion_map_layer.append( - (sd_up_res_prefix, hf_up_res_prefix) - ) - - if i > 0: - # no attention layers in up_blocks.0 - hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.' - sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.' - unet_conversion_map_layer.append( - (sd_up_atn_prefix, hf_up_atn_prefix) - ) - - if i < 3: - # no downsample in down_blocks.3 - hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.' - sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.' - unet_conversion_map_layer.append( - (sd_downsample_prefix, hf_downsample_prefix) - ) - - # no upsample in up_blocks.3 - hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' - sd_upsample_prefix = ( - f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.' 
- ) - unet_conversion_map_layer.append( - (sd_upsample_prefix, hf_upsample_prefix) - ) - - hf_mid_atn_prefix = 'mid_block.attentions.0.' - sd_mid_atn_prefix = 'middle_block.1.' - unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks for j in range(2): - hf_mid_res_prefix = f'mid_block.resnets.{j}.' - sd_mid_res_prefix = f'middle_block.{2*j}.' - unet_conversion_map_layer.append( - (sd_mid_res_prefix, hf_mid_res_prefix) - ) + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) - # buyer beware: this is a *brittle* function, - # and correct output requires that all of these pieces interact in - # the exact order in which I have arranged them. - mapping = {k: k for k in unet_state_dict.keys()} - for sd_name, hf_name in unet_conversion_map: - mapping[hf_name] = sd_name - for k, v in mapping.items(): - if 'resnets' in k: - for sd_part, hf_part in unet_conversion_map_resnet: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - for sd_part, hf_part in unet_conversion_map_layer: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) - if v2: - conv_transformer_to_linear(new_state_dict) + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) - return new_state_dict + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. 
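+  # order below: seed the mapping with identity, apply the direct top-level renames, then rewrite resnet-internal names, then block/layer prefixes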
+ mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict # ================# # VAE Conversion # # ================# - def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights - return w.reshape(*w.shape, 1, 1) + return w.reshape(*w.shape, 1, 1) def convert_vae_state_dict(vae_state_dict): - vae_conversion_map = [ - # (stable-diffusion, HF Diffusers) - ('nin_shortcut', 'conv_shortcut'), - ('norm_out', 'conv_norm_out'), - ('mid.attn_1.', 'mid_block.attentions.0.'), - ] + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] - for i in range(4): - # down_blocks have two resnets - for j in range(2): - hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.' - sd_down_prefix = f'encoder.down.{i}.block.{j}.' - vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) - if i < 3: - hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.' - sd_downsample_prefix = f'down.{i}.downsample.' - vae_conversion_map.append( - (sd_downsample_prefix, hf_downsample_prefix) - ) + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' - sd_upsample_prefix = f'up.{3-i}.upsample.' - vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) - # up_blocks have three resnets - # also, up blocks in hf are numbered in reverse from sd - for j in range(3): - hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.' - sd_up_prefix = f'decoder.up.{3-i}.block.{j}.' - vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) - # this part accounts for mid blocks in both the encoder and the decoder - for i in range(2): - hf_mid_res_prefix = f'mid_block.resnets.{i}.' - sd_mid_res_prefix = f'mid.block_{i+1}.' - vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." 
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) - vae_conversion_map_attn = [ - # (stable-diffusion, HF Diffusers) - ('norm.', 'group_norm.'), - ('q.', 'query.'), - ('k.', 'key.'), - ('v.', 'value.'), - ('proj_out.', 'proj_attn.'), - ] + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] - mapping = {k: k for k in vae_state_dict.keys()} - for k, v in mapping.items(): - for sd_part, hf_part in vae_conversion_map: - v = v.replace(hf_part, sd_part) - mapping[k] = v - for k, v in mapping.items(): - if 'attentions' in k: - for sd_part, hf_part in vae_conversion_map_attn: - v = v.replace(hf_part, sd_part) - mapping[k] = v - new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} - weights_to_convert = ['q', 'k', 'v', 'proj_out'] - for k, v in new_state_dict.items(): - for weight_name in weights_to_convert: - if f'mid.attn_1.{weight_name}.weight' in k: - # print(f"Reshaping {k} for SD format") - new_state_dict[k] = reshape_weight_for_sd(v) + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) - return new_state_dict + return new_state_dict # endregion # region 自作のモデル読み書きなど - def is_safetensors(path): - return os.path.splitext(path)[1].lower() == '.safetensors' + return os.path.splitext(path)[1].lower() == '.safetensors' def load_checkpoint_with_text_encoder_conversion(ckpt_path): - # text encoderの格納形式が違うモデルに対応する ('text_model'がない) - TEXT_ENCODER_KEY_REPLACEMENTS = [ - ( - 'cond_stage_model.transformer.embeddings.', - 'cond_stage_model.transformer.text_model.embeddings.', - ), - ( - 'cond_stage_model.transformer.encoder.', - 'cond_stage_model.transformer.text_model.encoder.', - ), - ( - 'cond_stage_model.transformer.final_layer_norm.', - 'cond_stage_model.transformer.text_model.final_layer_norm.', - ), - ] + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'), + ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'), + ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.') + ] - if is_safetensors(ckpt_path): - checkpoint = None - state_dict = load_file(ckpt_path, 'cpu') + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path, "cpu") + else: + checkpoint = torch.load(ckpt_path, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] else: - checkpoint = torch.load(ckpt_path, map_location='cpu') - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - else: - state_dict = checkpoint - checkpoint = None + state_dict = checkpoint + checkpoint = None - key_reps = [] - for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: - 
for key in state_dict.keys():
-            if key.startswith(rep_from):
-                new_key = rep_to + key[len(rep_from) :]
-                key_reps.append((key, new_key))
+  key_reps = []
+  for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+    for key in state_dict.keys():
+      if key.startswith(rep_from):
+        new_key = rep_to + key[len(rep_from):]
+        key_reps.append((key, new_key))
 
-    for key, new_key in key_reps:
-        state_dict[new_key] = state_dict[key]
-        del state_dict[key]
+  for key, new_key in key_reps:
+    state_dict[new_key] = state_dict[key]
+    del state_dict[key]
 
-    return checkpoint, state_dict
+  return checkpoint, state_dict
 
 
 # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認
 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
-    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
-    if dtype is not None:
-        for k, v in state_dict.items():
-            if type(v) is torch.Tensor:
-                state_dict[k] = v.to(dtype)
+  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+  if dtype is not None:
+    for k, v in state_dict.items():
+      if type(v) is torch.Tensor:
+        state_dict[k] = v.to(dtype)
 
-    # Convert the UNet2DConditionModel model.
-    unet_config = create_unet_diffusers_config(v2)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
-        v2, state_dict, unet_config
+  # Convert the UNet2DConditionModel model.
+  unet_config = create_unet_diffusers_config(v2)
+  converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+  unet = UNet2DConditionModel(**unet_config)
+  info = unet.load_state_dict(converted_unet_checkpoint)
+  print("loading u-net:", info)
+
+  # Convert the VAE model.
+  vae_config = create_vae_diffusers_config()
+  converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+  vae = AutoencoderKL(**vae_config)
+  info = vae.load_state_dict(converted_vae_checkpoint)
+  print("loading vae:", info)
+
+  # convert text_model
+  if v2:
+    converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+    cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=1024,
+        intermediate_size=4096,
+        num_hidden_layers=23,
+        num_attention_heads=16,
+        max_position_embeddings=77,
+        hidden_act="gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=512,
+        torch_dtype="float32",
+        transformers_version="4.25.0.dev0",
     )
+    text_model = CLIPTextModel._from_config(cfg)
+    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+  else:
+    converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+    info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+  print("loading text encoder:", info)
 
-    unet = UNet2DConditionModel(**unet_config)
-    info = unet.load_state_dict(converted_unet_checkpoint)
-    print('loading u-net:', info)
-
-    # Convert the VAE model.
- vae_config = create_vae_diffusers_config() - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - state_dict, vae_config - ) - - vae = AutoencoderKL(**vae_config) - info = vae.load_state_dict(converted_vae_checkpoint) - print('loadint vae:', info) - - # convert text_model - if v2: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2( - state_dict, 77 - ) - cfg = CLIPTextConfig( - vocab_size=49408, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=23, - num_attention_heads=16, - max_position_embeddings=77, - hidden_act='gelu', - layer_norm_eps=1e-05, - dropout=0.0, - attention_dropout=0.0, - initializer_range=0.02, - initializer_factor=1.0, - pad_token_id=1, - bos_token_id=0, - eos_token_id=2, - model_type='clip_text_model', - projection_dim=512, - torch_dtype='float32', - transformers_version='4.25.0.dev0', - ) - text_model = CLIPTextModel._from_config(cfg) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - else: - converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1( - state_dict - ) - text_model = CLIPTextModel.from_pretrained( - 'openai/clip-vit-large-patch14' - ) - info = text_model.load_state_dict(converted_text_encoder_checkpoint) - print('loading text encoder:', info) - - return text_model, vae, unet + return text_model, vae, unet -def convert_text_encoder_state_dict_to_sd_v2( - checkpoint, make_dummy_weights=False -): - def convert_key(key): - # position_idsの除去 - if '.position_ids' in key: - return None +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None - # common - key = key.replace('text_model.encoder.', 'transformer.') - key = key.replace('text_model.', '') - if 'layers' in key: - # resblocks conversion - key = key.replace('.layers.', '.resblocks.') - if '.layer_norm' in key: - key = key.replace('.layer_norm', '.ln_') - elif '.mlp.' in key: - key = key.replace('.fc1.', '.c_fc.') - key = key.replace('.fc2.', '.c_proj.') - elif '.self_attn.out_proj' in key: - key = key.replace('.self_attn.out_proj.', '.attn.out_proj.') - elif '.self_attn.' in key: - key = None # 特殊なので後で処理する - else: - raise ValueError(f'unexpected key in DiffUsers model: {key}') - elif '.position_embedding' in key: - key = key.replace( - 'embeddings.position_embedding.weight', 'positional_embedding' - ) - elif '.token_embedding' in key: - key = key.replace( - 'embeddings.token_embedding.weight', 'token_embedding.weight' - ) - elif 'final_layer_norm' in key: - key = key.replace('final_layer_norm', 'ln_final') - return key + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif '.self_attn.out_proj' in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif '.self_attn.' 
in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif '.position_embedding' in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif '.token_embedding' in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif 'final_layer_norm' in key: + key = key.replace("final_layer_norm", "ln_final") + return key - keys = list(checkpoint.keys()) - new_sd = {} + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if 'layers' in key and 'q_proj' in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) for key in keys: - new_key = convert_key(key) - if new_key is None: - continue - new_sd[new_key] = checkpoint[key] + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる - # attnの変換 - for key in keys: - if 'layers' in key and 'q_proj' in key: - # 三つを結合 - key_q = key - key_k = key.replace('q_proj', 'k_proj') - key_v = key.replace('q_proj', 'v_proj') + # Diffusersに含まれない重みを作っておく + new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd['logit_scale'] = torch.tensor(1) - value_q = checkpoint[key_q] - value_k = checkpoint[key_k] - value_v = checkpoint[key_v] - value = torch.cat([value_q, value_k, value_v]) - - new_key = key.replace( - 'text_model.encoder.layers.', 'transformer.resblocks.' - ) - new_key = new_key.replace('.self_attn.q_proj.', '.attn.in_proj_') - new_sd[new_key] = value - - # 最後の層などを捏造するか - if make_dummy_weights: - print( - 'make dummy weights for resblock.23, text_projection and logit scale.' 
- ) - keys = list(new_sd.keys()) - for key in keys: - if key.startswith('transformer.resblocks.22.'): - new_sd[key.replace('.22.', '.23.')] = new_sd[ - key - ].clone() # copyしないとsafetensorsの保存で落ちる - - # Diffusersに含まれない重みを作っておく - new_sd['text_projection'] = torch.ones( - (1024, 1024), - dtype=new_sd[keys[0]].dtype, - device=new_sd[keys[0]].device, - ) - new_sd['logit_scale'] = torch.tensor(1) - - return new_sd + return new_sd -def save_stable_diffusion_checkpoint( - v2, - output_file, - text_encoder, - unet, - ckpt_path, - epochs, - steps, - save_dtype=None, - vae=None, -): - if ckpt_path is not None: - # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む - checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion( - ckpt_path - ) - if checkpoint is None: # safetensors または state_dictのckpt - checkpoint = {} - strict = False - else: - strict = True - if 'state_dict' in state_dict: - del state_dict['state_dict'] +def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False else: - # 新しく作る - assert ( - vae is not None - ), 'VAE is required to save a checkpoint without a given checkpoint' - checkpoint = {} - state_dict = {} - strict = False + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False - def update_sd(prefix, sd): - for k, v in sd.items(): - key = prefix + k - assert ( - not strict or key in state_dict - ), f'Illegal key in save SD: {key}' - if save_dtype is not None: - v = v.detach().clone().to('cpu').to(save_dtype) - state_dict[key] = v + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v - # Convert the UNet model - unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) - update_sd('model.diffusion_model.', unet_state_dict) + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) - # Convert the text encoder model + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {'state_dict': state_dict} + + if 'epoch' in checkpoint: + epochs += checkpoint['epoch'] + if 'global_step' in checkpoint: + steps += checkpoint['global_step'] + + new_ckpt['epoch'] = epochs + new_ckpt['global_step'] = steps + + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file) + else: + 
torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 if v2: - make_dummy = ( - ckpt_path is None - ) # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる - text_enc_dict = convert_text_encoder_state_dict_to_sd_v2( - text_encoder.state_dict(), make_dummy - ) - update_sd('cond_stage_model.model.', text_enc_dict) + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 else: - text_enc_dict = text_encoder.state_dict() - update_sd('cond_stage_model.transformer.', text_enc_dict) + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - # Convert the VAE - if vae is not None: - vae_dict = convert_vae_state_dict(vae.state_dict()) - update_sd('first_stage_model.', vae_dict) + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") - # Put together new checkpoint - key_count = len(state_dict.keys()) - new_ckpt = {'state_dict': state_dict} - - if 'epoch' in checkpoint: - epochs += checkpoint['epoch'] - if 'global_step' in checkpoint: - steps += checkpoint['global_step'] - - new_ckpt['epoch'] = epochs - new_ckpt['global_step'] = steps - - if is_safetensors(output_file): - # TODO Tensor以外のdictの値を削除したほうがいいか - save_file(state_dict, output_file) - else: - torch.save(new_ckpt, output_file) - - return key_count + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) -def save_diffusers_checkpoint( - v2, - output_dir, - text_encoder, - unet, - pretrained_model_name_or_path, - vae=None, - use_safetensors=False, -): - if pretrained_model_name_or_path is None: - # load default settings for v1/v2 - if v2: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 - else: - pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 - - scheduler = DDIMScheduler.from_pretrained( - pretrained_model_name_or_path, subfolder='scheduler' - ) - tokenizer = CLIPTokenizer.from_pretrained( - pretrained_model_name_or_path, subfolder='tokenizer' - ) - if vae is None: - vae = AutoencoderKL.from_pretrained( - pretrained_model_name_or_path, subfolder='vae' - ) - - pipeline = StableDiffusionPipeline( - unet=unet, - text_encoder=text_encoder, - vae=vae, - scheduler=scheduler, - tokenizer=tokenizer, - safety_checker=None, - feature_extractor=None, - requires_safety_checker=None, - ) - pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) - - -VAE_PREFIX = 'first_stage_model.' +VAE_PREFIX = "first_stage_model." 
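+# VAE weights inside a full Stable Diffusion checkpoint are stored under this prefix; load_vae() below uses it to tell a full checkpoint from a standalone VAE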
def load_vae(vae_id, dtype): - print(f'load VAE: {vae_id}') - if os.path.isdir(vae_id) or not os.path.isfile(vae_id): - # Diffusers local/remote - try: - vae = AutoencoderKL.from_pretrained( - vae_id, subfolder=None, torch_dtype=dtype - ) - except EnvironmentError as e: - print(f'exception occurs in loading vae: {e}') - print("retry with subfolder='vae'") - vae = AutoencoderKL.from_pretrained( - vae_id, subfolder='vae', torch_dtype=dtype - ) - return vae - - # local - vae_config = create_vae_diffusers_config() - - if vae_id.endswith('.bin'): - # SD 1.5 VAE on Huggingface - vae_sd = torch.load(vae_id, map_location='cpu') - converted_vae_checkpoint = vae_sd - else: - # StableDiffusion - vae_model = torch.load(vae_id, map_location='cpu') - vae_sd = vae_model['state_dict'] - - # vae only or full model - full_model = False - for vae_key in vae_sd: - if vae_key.startswith(VAE_PREFIX): - full_model = True - break - if not full_model: - sd = {} - for key, value in vae_sd.items(): - sd[VAE_PREFIX + key] = value - vae_sd = sd - del sd - - # Convert the VAE model. - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - vae_sd, vae_config - ) - - vae = AutoencoderKL(**vae_config) - vae.load_state_dict(converted_vae_checkpoint) + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) return vae + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + vae_sd = torch.load(vae_id, map_location="cpu") + converted_vae_checkpoint = vae_sd + else: + # StableDiffusion + vae_model = torch.load(vae_id, map_location="cpu") + vae_sd = vae_model['state_dict'] + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. 
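+    # at this point vae_sd uses full-checkpoint key names (the "first_stage_model." prefix), whether it came from a full model or was re-prefixed above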
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + def get_epoch_ckpt_name(use_safetensors, epoch): - return f'epoch-{epoch:06d}' + ( - '.safetensors' if use_safetensors else '.ckpt' - ) + return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt") def get_last_ckpt_name(use_safetensors): - return f'last' + ('.safetensors' if use_safetensors else '.ckpt') + return f"last" + (".safetensors" if use_safetensors else ".ckpt") # endregion -def make_bucket_resolutions( - max_reso, min_size=256, max_size=1024, divisible=64 -): - max_width, max_height = max_reso - max_area = (max_width // divisible) * (max_height // divisible) +def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64): + max_width, max_height = max_reso + max_area = (max_width // divisible) * (max_height // divisible) - resos = set() + resos = set() - size = int(math.sqrt(max_area)) * divisible - resos.add((size, size)) + size = int(math.sqrt(max_area)) * divisible + resos.add((size, size)) - size = min_size - while size <= max_size: - width = size - height = min(max_size, (max_area // (width // divisible)) * divisible) - resos.add((width, height)) - resos.add((height, width)) + size = min_size + while size <= max_size: + width = size + height = min(max_size, (max_area // (width // divisible)) * divisible) + resos.add((width, height)) + resos.add((height, width)) - # # make additional resos - # if width >= height and width - divisible >= min_size: - # resos.add((width - divisible, height)) - # resos.add((height, width - divisible)) - # if height >= width and height - divisible >= min_size: - # resos.add((width, height - divisible)) - # resos.add((height - divisible, width)) + # # make additional resos + # if width >= height and width - divisible >= min_size: + # resos.add((width - divisible, height)) + # resos.add((height, width - divisible)) + # if height >= width and height - divisible >= min_size: + # resos.add((width, height - divisible)) + # resos.add((height - divisible, width)) - size += divisible + size += divisible - resos = list(resos) - resos.sort() + resos = list(resos) + resos.sort() - aspect_ratios = [w / h for w, h in resos] - return resos, aspect_ratios + aspect_ratios = [w / h for w, h in resos] + return resos, aspect_ratios if __name__ == '__main__': - resos, aspect_ratios = make_bucket_resolutions((512, 768)) - print(len(resos)) - print(resos) - print(aspect_ratios) + resos, aspect_ratios = make_bucket_resolutions((512, 768)) + print(len(resos)) + print(resos) + print(aspect_ratios) - ars = set() - for ar in aspect_ratios: - if ar in ars: - print('error! duplicate ar:', ar) - ars.add(ar) + ars = set() + for ar in aspect_ratios: + if ar in ars: + print("error! 
duplicate ar:", ar) + ars.add(ar) \ No newline at end of file diff --git a/train_network.py b/train_network.py index 3b2a8d1..cf171a7 100644 --- a/train_network.py +++ b/train_network.py @@ -925,11 +925,12 @@ def train(args): print(f"update token length: {args.max_token_length}") # 学習データを用意する + assert args.resolution is not None, f"resolution is required / resolution(解像度)を指定してください" resolution = tuple([int(r) for r in args.resolution.split(',')]) if len(resolution) == 1: resolution = (resolution[0], resolution[0]) assert len(resolution) == 2, \ - f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}" + f"resolution must be 'size' or 'width,height' / resolution(解像度)は'サイズ'または'幅','高さ'で指定してください: {args.resolution}" if args.face_crop_aug_range is not None: face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')]) @@ -1373,9 +1374,9 @@ if __name__ == '__main__': help="keep heading N tokens when shuffling caption tokens / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す") parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") - parser.add_argument("--in_json", type=str, default=None, help="json meatadata for dataset / データセットのmetadataのjsonファイル") + parser.add_argument("--in_json", type=str, default=None, help="json metadata for dataset / データセットのmetadataのjsonファイル") parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子") - parser.add_argument("--dataset_repeats", type=int, default=None, + parser.add_argument("--dataset_repeats", type=int, default=1, help="repeat dataset when training with captions / キャプションでの学習時にデータセットを繰り返す回数") parser.add_argument("--output_dir", type=str, default=None, help="directory to output trained model / 学習後のモデル出力先ディレクトリ")