add XL support for live previews: approx and TAESD
This commit is contained in:
parent
6f23da603d
commit
b8159d0919
@ -48,7 +48,7 @@ def extend_sdxl(model):
|
||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
||||
|
||||
model.is_xl = True
|
||||
model.is_sdxl = True
|
||||
|
||||
|
||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
|
@ -2,9 +2,9 @@ import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from modules import devices, paths
|
||||
from modules import devices, paths, shared
|
||||
|
||||
sd_vae_approx_model = None
|
||||
sd_vae_approx_models = {}
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
@ -31,19 +31,34 @@ class VAEApprox(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
print(f'Downloading VAEApprox model to: {model_path}')
|
||||
torch.hub.download_url_to_file(model_url, model_path)
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_approx_model
|
||||
model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
|
||||
loaded_model = sd_vae_approx_models.get(model_name)
|
||||
|
||||
if sd_vae_approx_model is None:
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
|
||||
sd_vae_approx_model = VAEApprox()
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
if not os.path.exists(model_path):
|
||||
model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
|
||||
sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
sd_vae_approx_model.eval()
|
||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||
model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
|
||||
|
||||
return sd_vae_approx_model
|
||||
if not os.path.exists(model_path):
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
|
||||
|
||||
loaded_model = VAEApprox()
|
||||
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_approx_models[model_name] = loaded_model
|
||||
|
||||
return loaded_model
|
||||
|
||||
|
||||
def cheap_approximation(sample):
|
||||
|
@ -8,9 +8,9 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal
|
||||
from modules import devices, paths_internal, shared
|
||||
|
||||
sd_vae_taesd = None
|
||||
sd_vae_taesd_models = {}
|
||||
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
@ -61,9 +61,7 @@ class TAESD(nn.Module):
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
|
||||
def download_model(model_path):
|
||||
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
@ -72,17 +70,19 @@ def download_model(model_path):
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_taesd
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if sd_vae_taesd is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
|
||||
download_model(model_path)
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
sd_vae_taesd = TAESD(model_path)
|
||||
sd_vae_taesd.eval()
|
||||
sd_vae_taesd.to(devices.device, devices.dtype)
|
||||
loaded_model = TAESD(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
else:
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return sd_vae_taesd.decoder
|
||||
return loaded_model.decoder
|
||||
|
Loading…
Reference in New Issue
Block a user