--no-half-vae

This commit is contained in:
AUTOMATIC 2022-10-10 16:11:14 +03:00
parent a357823339
commit 7349088d32
5 changed files with 20 additions and 5 deletions

View File

@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16
def randn(seed, shape): def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
@ -59,9 +60,12 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device) return torch.randn(shape, device=device)
def autocast(): def autocast(disable=False):
from modules import shared from modules import shared
if disable:
return contextlib.nullcontext()
if dtype == torch.float32 or shared.cmd_opts.precision == "full": if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext() return contextlib.nullcontext()

View File

@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x return x
def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
x = model.decode_first_stage(x)
return x
def get_fixed_seed(seed): def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1: if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294)) return int(random.randrange(4294967294))
@ -400,7 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
samples_ddim = samples_ddim.to(devices.dtype) samples_ddim = samples_ddim.to(devices.dtype)
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
del samples_ddim del samples_ddim
@ -533,7 +540,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scale_latent: if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else: else:
decoded_samples = self.sd_model.decode_first_stage(samples) decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear")

View File

@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info):
model.half() model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if os.path.exists(vae_file): if os.path.exists(vae_file):
@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info

View File

@ -7,7 +7,7 @@ import inspect
import k_diffusion.sampling import k_diffusion.sampling
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
from modules import prompt_parser from modules import prompt_parser, devices, processing
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None):
def sample_to_image(samples): def sample_to_image(samples):
x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)

View File

@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")