diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 905cbeef..e493f366 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,47 +1,60 @@ +import csv import datetime import glob import html import os import sys import traceback -import tqdm -import csv -import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum -from einops import rearrange, repeat import modules.textual_inversion.dataset +import torch +import tqdm +from einops import rearrange, repeat +from ldm.util import default +from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler +from torch import einsum class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU, - "swish": torch.nn.Hardswish} - - def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): + activation_dict = { + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + } + + def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): super().__init__() assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - + assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" + linears = [] for i in range(len(layer_structure) - 1): + + # Add a fully-connected layer linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - # if skip_first_layer because first parameters potentially contain negative values - # if i < 1: continue - if activation_func in HypernetworkModule.activation_dict: - linears.append(HypernetworkModule.activation_dict[activation_func]()) + + # Add an activation func + if activation_func == "linear": + pass + elif activation_func in self.activation_dict: + linears.append(self.activation_dict[activation_func]()) else: - print("Invalid key {} encountered as activation function!".format(activation_func)) - # if use_dropout: - # linears.append(torch.nn.Dropout(p=0.3)) + raise NotImplementedError( + "Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'" + ) + + # Add dropout + if use_dropout: + linears.append(torch.nn.Dropout(p=0.3)) + + # Add layer normalization if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) @@ -93,7 +106,7 @@ class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): self.filename = None self.name = name self.layers = {} @@ -101,13 +114,14 @@ class Hypernetwork: self.sd_checkpoint = None self.sd_checkpoint_name = None self.layer_structure = layer_structure - self.add_layer_norm = add_layer_norm self.activation_func = activation_func + self.add_layer_norm = add_layer_norm + self.use_dropout = use_dropout for size in enable_sizes or []: self.layers[size] = ( - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), - HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), + HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout), ) def weights(self): @@ -129,8 +143,9 @@ class Hypernetwork: state_dict['step'] = self.step state_dict['name'] = self.name state_dict['layer_structure'] = self.layer_structure - state_dict['is_layer_norm'] = self.add_layer_norm state_dict['activation_func'] = self.activation_func + state_dict['is_layer_norm'] = self.add_layer_norm + state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -144,8 +159,9 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - self.add_layer_norm = state_dict.get('is_layer_norm', False) self.activation_func = state_dict.get('activation_func', None) + self.add_layer_norm = state_dict.get('is_layer_norm', False) + self.use_dropout = state_dict.get('use_dropout', False) for size, sd in state_dict.items(): if type(size) == int: diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 1a5a27d8..5f6f17b6 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -3,14 +3,13 @@ import os import re import gradio as gr - -import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared, devices +import modules.textual_inversion.textual_inversion +from modules import devices, sd_hijack, shared from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None): +def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" @@ -21,8 +20,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm name=name, enable_sizes=[int(x) for x in enable_sizes], layer_structure=layer_structure, - add_layer_norm=add_layer_norm, activation_func=activation_func, + add_layer_norm=add_layer_norm, + use_dropout=use_dropout, ) hypernet.save(fn) diff --git a/modules/ui.py b/modules/ui.py index 716f14b8..d4b32c05 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -5,43 +5,44 @@ import json import math import mimetypes import os +import platform import random +import subprocess as sp import sys import tempfile import time import traceback -import platform -import subprocess as sp from functools import partial, reduce +import gradio as gr +import gradio.routes +import gradio.utils import numpy as np +import piexif import torch from PIL import Image, PngImagePlugin -import piexif -import gradio as gr -import gradio.utils -import gradio.routes - -from modules import sd_hijack, sd_models, localization +from modules import localization, sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts +from modules.shared import cmd_opts, opts, restricted_opts + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags -import modules.shared as shared -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.sd_hijack import model_hijack -import modules.ldsr_model -import modules.scripts -import modules.gfpgan_model + import modules.codeformer_model -import modules.styles import modules.generation_parameters_copypaste -from modules import prompt_parser -from modules.images import save_image -import modules.textual_inversion.ui +import modules.gfpgan_model import modules.hypernetworks.ui import modules.images_history as img_his +import modules.ldsr_model +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1223,8 +1224,9 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") + new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"]) new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"]) + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") with gr.Row(): with gr.Column(scale=3): @@ -1308,8 +1310,9 @@ def create_ui(wrap_gradio_gpu_call): new_hypernetwork_name, new_hypernetwork_sizes, new_hypernetwork_layer_structure, - new_hypernetwork_add_layer_norm, new_hypernetwork_activation_func, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout ], outputs=[ train_hypernetwork_name,