support for generating images on video cards with 4GB

AUTOMATIC 2022-08-29 01:58:15 +03:00
parent 7a7a3a6b19
commit 9c9f048b5e


@@ -2,6 +2,8 @@ import argparse
 import os
 import sys
 from collections import namedtuple
+from contextlib import nullcontext
+
 import torch
 import torch.nn as nn
 import numpy as np
@ -51,6 +53,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--lowvram", action='store_true', help="enamble optimizations for low vram")
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
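In practice the new flag is passed on the command line when launching the script. The invocation below is only a sketch: the script name and checkpoint path are placeholders, with --ckpt inferred from the cmd_opts.ckpt reference further down in this diff.

    python webui.py --lowvram --ckpt model.ckpt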
@@ -185,11 +188,80 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)

-    model.cuda()
     model.eval()
     return model


+module_in_gpu = None
+
+
+def setup_for_low_vram(sd_model):
+    parents = {}
+
+    def send_me_to_gpu(module, _):
+        """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
+        we add this as forward_pre_hook to a lot of modules and this way all but one of them will
+        be in CPU
+        """
+        global module_in_gpu
+
+        module = parents.get(module, module)
+
+        if module_in_gpu == module:
+            return
+
+        if module_in_gpu is not None:
+            print('removing from gpu:', type(module_in_gpu))
+            module_in_gpu.to(cpu)
+
+        print('adding to gpu:', type(module))
+        module.to(gpu)
+        print('added to gpu:', type(module))
+
+        module_in_gpu = module
+
+    # see below for register_forward_pre_hook;
+    # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
+    # useless here, and we just replace those methods
+    def first_stage_model_encode_wrap(self, encoder, x):
+        send_me_to_gpu(self, None)
+        return encoder(x)
+
+    def first_stage_model_decode_wrap(self, decoder, z):
+        send_me_to_gpu(self, None)
+        return decoder(z)
+
+    # remove three big modules, cond, first_stage, and unet from the model and then
+    # send the model to GPU. Then put modules back. the modules will be in CPU.
+    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
+    sd_model.to(device)
+    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
+
+    # register hooks for the first two models
+    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
+    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
+    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+    parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
+
+    # the third remaining model is still too big for 4GB, so we also do the same for its submodules
+    # so that only one of them is in GPU at a time
+    diff_model = sd_model.model.diffusion_model
+
+    stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
+    diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
+    sd_model.model.to(device)
+    diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
+
+    # install hooks for bits of third model
+    diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
+    for block in diff_model.input_blocks:
+        block.register_forward_pre_hook(send_me_to_gpu)
+    diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
+    for block in diff_model.output_blocks:
+        block.register_forward_pre_hook(send_me_to_gpu)
+
+
 def create_random_tensors(shape, seeds):
     xs = []
     for seed in seeds:
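The function above relies on PyTorch forward pre-hooks: every heavy submodule gets a hook that moves it to the GPU right before its forward pass and evicts whichever module held the GPU previously, so at most one of them occupies VRAM at a time (parents maps a hooked child back to the wrapper that should be moved as a unit). The snippet below is a minimal, self-contained sketch of that idea on an invented toy model; it is not part of the commit and assumes a CUDA device is available.

    import torch
    import torch.nn as nn

    cpu = torch.device("cpu")
    gpu = torch.device("cuda")

    module_in_gpu = None

    def send_me_to_gpu(module, _inputs):
        # forward pre-hook: move the module that is about to run onto the GPU,
        # and evict whichever module currently occupies it back to system RAM
        global module_in_gpu
        if module_in_gpu is module:
            return
        if module_in_gpu is not None:
            module_in_gpu.to(cpu)
        module.to(gpu)
        module_in_gpu = module

    toy = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
    for layer in toy:
        layer.register_forward_pre_hook(send_me_to_gpu)

    x = torch.randn(1, 512, device=gpu)
    y = toy(x)  # each layer hops onto the GPU just before its own forward() runs

The commit applies the same trick at a coarser granularity: the text encoder, the VAE, and the individual UNet blocks take turns on the GPU, so peak VRAM usage is driven by the largest single block plus activations rather than by the whole model.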
@@ -838,7 +910,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
     model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)

     output_images = []
-    with torch.no_grad(), autocast("cuda"), model.ema_scope():
+    ema_scope = (nullcontext if cmd_opts.lowvram else model.ema_scope)
+    with torch.no_grad(), autocast("cuda"), ema_scope():
         p.init()

         for n in range(p.n_iter):
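Both nullcontext and model.ema_scope are callables that return a context manager, so the with statement stays unchanged; under --lowvram the EMA weight swap is simply skipped, presumably because it would touch the full set of weights that low-VRAM mode is trying to keep off the GPU. A minimal sketch of the pattern, with DummyModel standing in for the real LatentDiffusion class (it is not code from the commit):

    from contextlib import nullcontext

    class DummyModel:
        def ema_scope(self):
            # stand-in for the real ema_scope(), which temporarily swaps in EMA weights
            return nullcontext()

    model = DummyModel()
    lowvram = True  # stand-in for cmd_opts.lowvram

    ema_scope = (nullcontext if lowvram else model.ema_scope)
    with ema_scope():
        pass  # sampling would run here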
@@ -1327,8 +1400,17 @@ interfaces = [
 sd_config = OmegaConf.load(cmd_opts.config)
 sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)

-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-sd_model = (sd_model if cmd_opts.no_half else sd_model.half()).to(device)
+cpu = torch.device("cpu")
+gpu = torch.device("cuda")
+device = gpu if torch.cuda.is_available() else cpu
+
+sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
+
+if not cmd_opts.lowvram:
+    sd_model = sd_model.to(device)
+else:
+    setup_for_low_vram(sd_model)

 model_hijack = StableDiffusionModelHijack()
 model_hijack.hijack(sd_model)
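A quick way to see the effect of the --lowvram path is to check where the heavy submodules' weights live right after loading: they should all report cpu until their first forward pass triggers send_me_to_gpu. The helper below is a hedged sketch and not part of the commit; only the attribute names are taken from the code above.

    def report_devices(sd_model):
        # print where the first parameter of each heavy submodule currently lives
        for name, module in [
            ("cond_stage_model.transformer", sd_model.cond_stage_model.transformer),
            ("first_stage_model", sd_model.first_stage_model),
            ("model.diffusion_model", sd_model.model.diffusion_model),
        ]:
            print(name, next(module.parameters()).device)

    report_devices(sd_model)  # with --lowvram, expect all three to print "cpu" at this point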