custom unet support

AUTOMATIC 2023-05-27 15:47:33 +03:00
parent a6e653be26
commit 339b531570
8 changed files with 148 additions and 8 deletions

View File

@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
sd_unet.apply_unet()
if state.job_count == -1:
state.job_count = p.n_iter

View File

@@ -111,6 +111,7 @@ callback_map = dict(
    callbacks_before_ui=[],
    callbacks_on_reload=[],
    callbacks_list_optimizers=[],
    callbacks_list_unets=[],
)
@@ -271,6 +272,18 @@ def list_optimizers_callback():
    return res


def list_unets_callback():
    res = []

    for c in callback_map['callbacks_list_unets']:
        try:
            c.callback(res)
        except Exception:
            report_exception(c, 'list_unets')

    return res
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -430,3 +443,10 @@ def on_list_optimizers(callback):
    to it."""
    add_callback(callback_map['callbacks_list_optimizers'], callback)


def on_list_unets(callback):
    """register a function to be called when the UI is making a list of alternative options for the unet.

    The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it."""
    add_callback(callback_map['callbacks_list_unets'], callback)
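For illustration only, a minimal sketch of how an extension script could use this hook. The label and checkpoint name below are made up; a real extension would subclass SdUnetOption and implement create_unet(), as in the sketch after modules/sd_unet.py further down.

from modules import script_callbacks, sd_unet

def list_unets(unets):
    # webui passes in a shared list; append SdUnetOption objects to make them selectable
    opt = sd_unet.SdUnetOption()
    opt.label = "Example unet"               # hypothetical label shown in the SD Unet dropdown
    opt.model_name = "v1-5-pruned-emaonly"   # hypothetical checkpoint name used for "Automatic" matching
    unets.append(opt)

script_callbacks.on_list_unets(list_unets)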

View File

@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType
import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,7 +43,7 @@ def list_optimizers():
optimizers.extend(new_optimizers)
def apply_optimizations():
def apply_optimizations(option=None):
global current_optimizer
undo_optimizations()
@@ -60,7 +60,7 @@ def apply_optimizations():
current_optimizer.undo()
current_optimizer = None
selection = shared.opts.cross_attention_optimization
selection = option or shared.opts.cross_attention_optimization
if selection == "Automatic" and len(optimizers) > 0:
matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
else:
@@ -72,12 +72,13 @@ def apply_optimizations():
matching_optimizer = optimizers[0]
if matching_optimizer is not None:
print(f"Applying optimization: {matching_optimizer.name}... ", end='')
print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
matching_optimizer.apply()
print("done.")
current_optimizer = matching_optimizer
return current_optimizer.name
else:
print("Disabling attention optimization")
return ''
@@ -155,9 +156,9 @@ class StableDiffusionModelHijack:
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self):
def apply_optimizations(self, option=None):
try:
self.optimization_method = apply_optimizations()
self.optimization_method = apply_optimizations(option)
except Exception as e:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
@@ -194,6 +195,11 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
@@ -215,6 +221,8 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
def apply_circular(self, enable):
if self.circular_enabled == enable:
return

View File

@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -532,6 +532,8 @@ def reload_model_weights(sd_model=None, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
sd_unet.apply_unet("None")
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:

modules/sd_unet.py (new file, 92 lines added)
View File

@@ -0,0 +1,92 @@
import torch.nn
import ldm.modules.diffusionmodules.openaimodel

from modules import script_callbacks, shared, devices

unet_options = []
current_unet_option = None
current_unet = None


def list_unets():
    new_unets = script_callbacks.list_unets_callback()

    unet_options.clear()
    unet_options.extend(new_unets)


def get_unet_option(option=None):
    option = option or shared.opts.sd_unet

    if option == "None":
        return None

    if option == "Automatic":
        name = shared.sd_model.sd_checkpoint_info.model_name

        options = [x for x in unet_options if x.model_name == name]

        option = options[0].label if options else "None"

    return next(iter([x for x in unet_options if x.label == option]), None)


def apply_unet(option=None):
    global current_unet_option
    global current_unet

    new_option = get_unet_option(option)
    if new_option == current_unet_option:
        return

    if current_unet is not None:
        print(f"Deactivating unet: {current_unet.option.label}")
        current_unet.deactivate()

    current_unet_option = new_option
    if current_unet_option is None:
        current_unet = None

        if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
            shared.sd_model.model.diffusion_model.to(devices.device)

        return

    shared.sd_model.model.diffusion_model.to(devices.cpu)
    devices.torch_gc()

    current_unet = current_unet_option.create_unet()
    current_unet.option = current_unet_option
    print(f"Activating unet: {current_unet.option.label}")
    current_unet.activate()


class SdUnetOption:
    model_name = None
    """name of the related checkpoint - this option will be selected automatically for the unet if the name of the checkpoint matches this"""

    label = None
    """name of the unet shown in the UI"""

    def create_unet(self):
        """returns an SdUnet object to be used as the unet instead of the built-in unet when making pictures"""
        raise NotImplementedError()


class SdUnet(torch.nn.Module):
    def forward(self, x, timesteps, context, *args, **kwargs):
        raise NotImplementedError()

    def activate(self):
        pass

    def deactivate(self):
        pass


def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
    if current_unet is not None:
        return current_unet.forward(x, timesteps, context, *args, **kwargs)

    return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
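To make the intended use of the two classes above concrete, here is a rough, hypothetical sketch of a custom unet an extension could provide. MyUnet, MyUnetOption, the checkpoint name and the torch.load call are all illustrative assumptions, not part of this commit; the option would be registered through script_callbacks.on_list_unets as in the earlier sketch.

import torch

from modules import devices, sd_unet


class MyUnet(sd_unet.SdUnet):
    def __init__(self):
        super().__init__()
        self.model = None

    def activate(self):
        # called by apply_unet() after the built-in diffusion model has been moved to CPU;
        # loading replacement weights from a file here is a made-up example
        self.model = torch.load("my_unet.pt", map_location=devices.device)

    def deactivate(self):
        # drop the replacement model and free VRAM again
        self.model = None
        devices.torch_gc()

    def forward(self, x, timesteps, context, *args, **kwargs):
        # receives the same arguments UNetModel.forward would get, routed through UNetModel_forward above
        return self.model(x, timesteps, context)


class MyUnetOption(sd_unet.SdUnetOption):
    model_name = "v1-5-pruned-emaonly"  # hypothetical; compared to the loaded checkpoint's name when the setting is "Automatic"
    label = "My custom unet"            # text shown in the SD Unet dropdown

    def create_unet(self):
        return MyUnet()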

View File

@@ -403,6 +403,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),

View File

@@ -29,3 +29,14 @@ def cross_attention_optimizations():
    return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"]


def sd_unet_items():
    import modules.sd_unet

    return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"]


def refresh_unet_list():
    import modules.sd_unet

    modules.sd_unet.list_unets()

View File

@@ -58,6 +58,7 @@ import modules.sd_hijack
import modules.sd_hijack_optimizations
import modules.sd_models
import modules.sd_vae
import modules.sd_unet
import modules.txt2img
import modules.script_callbacks
import modules.textual_inversion.textual_inversion
@@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False):
modules.sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
modules.sd_unet.list_unets()
startup_timer.record("scripts list_unets")
def load_model():
"""
Accesses shared.sd_model property to load model.