diff --git a/.gitignore b/.gitignore index 69785b3e..4794865c 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion +/hypernetwork diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4905710e..cadb9911 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,52 +1,98 @@ +import csv import datetime import glob import html import os import sys import traceback -import tqdm -import csv -import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum -from einops import rearrange, repeat import modules.textual_inversion.dataset +import torch +import tqdm +from einops import rearrange, repeat +from ldm.util import default +from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler +from torch import einsum + + +def parse_layer_structure(dim, state_dict): + i = 0 + res = [1] + while (key := "linear.{}.weight".format(i)) in state_dict: + weight = state_dict[key] + res.append(len(weight) // dim) + i += 1 + return res class HypernetworkModule(torch.nn.Module): multiplier = 1.0 + layer_structure = None + add_layer_norm = False def __init__(self, dim, state_dict=None): super().__init__() + if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None: + layer_structure = (1, 2, 1) + else: + if self.layer_structure is not None: + assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" + layer_structure = self.layer_structure + else: + layer_structure = parse_layer_structure(dim, state_dict) - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) + linears = [] + for i in range(len(layer_structure) - 1): + linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if self.add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) + + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - self.load_state_dict(state_dict, strict=True) + try: + self.load_state_dict(state_dict) + except RuntimeError: + self.try_load_previous(state_dict) else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() + for layer in self.linear: + layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.bias.data.zero_() self.to(devices.device) + def try_load_previous(self, state_dict): + states = self.state_dict() + states['linear.0.bias'].copy_(state_dict['linear1.bias']) + states['linear.0.weight'].copy_(state_dict['linear1.weight']) + states['linear.1.bias'].copy_(state_dict['linear2.bias']) + states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def forward(self, x): - return x + (self.linear2(self.linear1(x))) * self.multiplier + return x + self.linear(x) * self.multiplier + + def trainables(self): + res = [] + for layer in self.linear: + res += [layer.weight, layer.bias] + return res def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +def apply_layer_structure(value=None): + HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure + + +def apply_layer_norm(value=None): + HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm + + class Hypernetwork: filename = None name = None @@ -68,7 +114,7 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + res += layer.trainables() return res @@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) - + assert ds.length > 1, "Dataset should contain more than 1 images" if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) @@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"): c = stack_conds([entry.cond for entry in entries]).to(devices.device) -# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + c = torch.vstack([entry.cond for entry in entries]).to(devices.device) x = torch.stack([entry.latent for entry in entries]).to(devices.device) loss = shared.sd_model(x, c)[0] del x @@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{mean_loss:.7f}", - "learn_rate": scheduler.learn_rate + "learn_rate": f"{scheduler.learn_rate:.7f}" }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: diff --git a/modules/shared.py b/modules/shared.py index c0d87168..c87ce70e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization +from modules import sd_models, sd_samplers, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}), + "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/webui.py b/webui.py index fe0ce321..86e98ad0 100644 --- a/webui.py +++ b/webui.py @@ -86,11 +86,13 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure) + shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm) def webui(): initialize() - + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -101,7 +103,7 @@ def webui(): while 1: demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - + app, local_url, share_url = demo.launch( share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None,