diff --git a/webui.py b/webui.py index 0e4fb399..0fad34df 100644 --- a/webui.py +++ b/webui.py @@ -2,6 +2,8 @@ import argparse import os import sys from collections import namedtuple +from contextlib import nullcontext + import torch import torch.nn as nn import numpy as np @@ -51,6 +53,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") +parser.add_argument("--lowvram", action='store_true', help="enamble optimizations for low vram") cmd_opts = parser.parse_args() @@ -185,11 +188,80 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() model.eval() return model +module_in_gpu = None + + +def setup_for_low_vram(sd_model): + parents = {} + + def send_me_to_gpu(module, _): + """send this module to GPU; send whatever tracked module was previous in GPU to CPU; + we add this as forward_pre_hook to a lot of modules and this way all but one of them will + be in CPU + """ + global module_in_gpu + + module = parents.get(module, module) + + if module_in_gpu == module: + return + + if module_in_gpu is not None: + print('removing from gpu:', type(module_in_gpu)) + module_in_gpu.to(cpu) + + print('adding to gpu:', type(module)) + module.to(gpu) + + print('added to gpu:', type(module)) + module_in_gpu = module + + # see below for register_forward_pre_hook; + # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is + # useless here, and we just replace those methods + def first_stage_model_encode_wrap(self, encoder, x): + send_me_to_gpu(self, None) + return encoder(x) + + def first_stage_model_decode_wrap(self, decoder, z): + send_me_to_gpu(self, None) + return decoder(z) + + # remove three big modules, cond, first_stage, and unet from the model and then + # send the model to GPU. Then put modules back. the modules will be in CPU. + stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model + sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None + sd_model.to(device) + sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored + + # register hooks for those the first two models + sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) + sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x) + sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z) + parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model + + # the third remaining model is still too big for 4GB, so we also do the same for its submodules + # so that only one of them is in GPU at a time + diff_model = sd_model.model.diffusion_model + stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed + diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None + sd_model.model.to(device) + diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored + + # install hooks for bits of third model + diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) + for block in diff_model.input_blocks: + block.register_forward_pre_hook(send_me_to_gpu) + diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) + for block in diff_model.output_blocks: + block.register_forward_pre_hook(send_me_to_gpu) + + def create_random_tensors(shape, seeds): xs = [] for seed in seeds: @@ -838,7 +910,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model) output_images = [] - with torch.no_grad(), autocast("cuda"), model.ema_scope(): + ema_scope = (nullcontext if cmd_opts.lowvram else model.ema_scope) + with torch.no_grad(), autocast("cuda"), ema_scope(): p.init() for n in range(p.n_iter): @@ -1327,8 +1400,17 @@ interfaces = [ sd_config = OmegaConf.load(cmd_opts.config) sd_model = load_model_from_config(sd_config, cmd_opts.ckpt) -device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") -sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device) +cpu = torch.device("cpu") +gpu = torch.device("cuda") +device = gpu if torch.cuda.is_available() else cpu + +sd_model = (sd_model if cmd_opts.no_half else sd_model.half()) + +if not cmd_opts.lowvram: + sd_model = sd_model.to(device) + +else: + setup_for_low_vram(sd_model) model_hijack = StableDiffusionModelHijack() model_hijack.hijack(sd_model)