commit
0198eaec45
@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||
|
||||
return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
|
||||
|
||||
if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
|
||||
if 'mlp_fc1' in m[1]:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||
elif 'mlp_fc2' in m[1]:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
|
||||
else:
|
||||
return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
|
||||
|
||||
return key
|
||||
|
||||
|
||||
@ -147,10 +155,20 @@ class LoraUpDownModule:
|
||||
def assign_lora_names_to_compvis_modules(sd_model):
|
||||
lora_layer_mapping = {}
|
||||
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
if shared.sd_model.is_sdxl:
|
||||
for i, embedder in enumerate(shared.sd_model.conditioner.embedders):
|
||||
if not hasattr(embedder, 'wrapped'):
|
||||
continue
|
||||
|
||||
for name, module in embedder.wrapped.named_modules():
|
||||
lora_name = f'{i}_{name.replace(".", "_")}'
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
else:
|
||||
for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
lora_layer_mapping[lora_name] = module
|
||||
module.lora_layer_name = lora_name
|
||||
|
||||
for name, module in shared.sd_model.model.named_modules():
|
||||
lora_name = name.replace(".", "_")
|
||||
@ -173,10 +191,10 @@ def load_lora(name, lora_on_disk):
|
||||
keys_failed_to_match = {}
|
||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||
|
||||
for key_diffusers, weight in sd.items():
|
||||
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
|
||||
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
|
||||
for key_lora, weight in sd.items():
|
||||
key_lora_without_lora_parts, lora_key = key_lora.split(".", 1)
|
||||
|
||||
key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2)
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
@ -184,8 +202,16 @@ def load_lora(name, lora_on_disk):
|
||||
if m:
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
|
||||
|
||||
# SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
|
||||
if sd_module is None and "lora_unet" in key_lora_without_lora_parts:
|
||||
key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model")
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts:
|
||||
key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model")
|
||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||
|
||||
if sd_module is None:
|
||||
keys_failed_to_match[key_diffusers] = key
|
||||
keys_failed_to_match[key_lora] = key
|
||||
continue
|
||||
|
||||
lora_module = lora.modules.get(key, None)
|
||||
@ -208,9 +234,9 @@ def load_lora(name, lora_on_disk):
|
||||
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||
else:
|
||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||
print(f'Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||
continue
|
||||
raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
|
||||
raise AssertionError(f"Lora layer {key_lora} matched a layer with unsupported type: {type(sd_module).__name__}")
|
||||
|
||||
with torch.no_grad():
|
||||
module.weight.copy_(weight)
|
||||
@ -222,7 +248,7 @@ def load_lora(name, lora_on_disk):
|
||||
elif lora_key == "lora_down.weight":
|
||||
lora_module.down = module
|
||||
else:
|
||||
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||
raise AssertionError(f"Bad Lora layer name: {key_lora} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||
|
||||
if keys_failed_to_match:
|
||||
print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")
|
||||
|
@ -378,7 +378,7 @@ def apply_hypernetworks(hypernetworks, context, layer=None):
|
||||
return context_k, context_v
|
||||
|
||||
|
||||
def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
||||
def attention_CrossAttention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
|
@ -237,11 +237,13 @@ def prepare_environment():
|
||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||
|
||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
|
||||
k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
|
||||
codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
|
||||
blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
|
||||
|
||||
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
|
||||
stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
|
||||
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
|
||||
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
|
||||
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
|
||||
@ -299,6 +301,7 @@ def prepare_environment():
|
||||
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
|
||||
|
||||
git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
|
||||
git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
|
||||
git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
|
||||
git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
|
||||
git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
|
||||
@ -323,6 +326,7 @@ def prepare_environment():
|
||||
exit(0)
|
||||
|
||||
|
||||
|
||||
def configure_for_tests():
|
||||
if "--api" not in sys.argv:
|
||||
sys.argv.append("--api")
|
||||
|
@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
send_me_to_gpu(first_stage_model, None)
|
||||
return first_stage_model_decode(z)
|
||||
|
||||
# for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
|
||||
to_remain_in_cpu = [
|
||||
(sd_model, 'first_stage_model'),
|
||||
(sd_model, 'depth_model'),
|
||||
(sd_model, 'embedder'),
|
||||
(sd_model, 'model'),
|
||||
(sd_model, 'embedder'),
|
||||
]
|
||||
|
||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
|
||||
# send the model to GPU. Then put modules back. the modules will be in CPU.
|
||||
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
|
||||
is_sdxl = hasattr(sd_model, 'conditioner')
|
||||
is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
|
||||
|
||||
if is_sdxl:
|
||||
to_remain_in_cpu.append((sd_model, 'conditioner'))
|
||||
elif is_sd2:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
|
||||
else:
|
||||
to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
|
||||
|
||||
# remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
|
||||
stored = []
|
||||
for obj, field in to_remain_in_cpu:
|
||||
module = getattr(obj, field, None)
|
||||
stored.append(module)
|
||||
setattr(obj, field, None)
|
||||
|
||||
# send the model to GPU.
|
||||
sd_model.to(devices.device)
|
||||
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
|
||||
|
||||
# put modules back. the modules will be in CPU.
|
||||
for (obj, field), module in zip(to_remain_in_cpu, stored):
|
||||
setattr(obj, field, module)
|
||||
|
||||
# register hooks for those the first three models
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
if is_sdxl:
|
||||
sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
|
||||
elif is_sd2:
|
||||
sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
else:
|
||||
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
||||
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
sd_model.first_stage_model.encode = first_stage_model_encode_wrap
|
||||
sd_model.first_stage_model.decode = first_stage_model_decode_wrap
|
||||
@ -73,11 +100,9 @@ def setup_for_low_vram(sd_model, use_medvram):
|
||||
sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
|
||||
if sd_model.embedder:
|
||||
sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if hasattr(sd_model.cond_stage_model, 'model'):
|
||||
sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
|
||||
del sd_model.cond_stage_model.transformer
|
||||
if hasattr(sd_model, 'cond_stage_model'):
|
||||
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
|
||||
|
||||
if use_medvram:
|
||||
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
|
||||
|
@ -5,6 +5,21 @@ from modules.paths_internal import models_path, script_path, data_path, extensio
|
||||
import modules.safe # noqa: F401
|
||||
|
||||
|
||||
def mute_sdxl_imports():
|
||||
"""create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
module = Dummy()
|
||||
module.LPIPS = None
|
||||
sys.modules['taming.modules.losses.lpips'] = module
|
||||
|
||||
module = Dummy()
|
||||
module.StableDataModuleFromConfig = None
|
||||
sys.modules['sgm.data'] = module
|
||||
|
||||
|
||||
# data_path = cmd_opts_pre.data
|
||||
sys.path.insert(0, script_path)
|
||||
|
||||
@ -18,8 +33,11 @@ for possible_sd_path in possible_sd_paths:
|
||||
|
||||
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
||||
|
||||
mute_sdxl_imports()
|
||||
|
||||
path_dirs = [
|
||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||
(os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
|
||||
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
|
||||
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
|
||||
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
|
||||
@ -35,6 +53,13 @@ for d, must_exist, what, options in path_dirs:
|
||||
d = os.path.abspath(d)
|
||||
if "atstart" in options:
|
||||
sys.path.insert(0, d)
|
||||
elif "sgm" in options:
|
||||
# Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
|
||||
# import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
|
||||
|
||||
sys.path.insert(0, d)
|
||||
import sgm # noqa: F401
|
||||
sys.path.pop(0)
|
||||
else:
|
||||
sys.path.append(d)
|
||||
paths[what] = d
|
||||
|
@ -330,8 +330,21 @@ class StableDiffusionProcessing:
|
||||
|
||||
caches is a list with items described above.
|
||||
"""
|
||||
|
||||
cached_params = (
|
||||
required_prompts,
|
||||
steps,
|
||||
opts.CLIP_stop_at_last_layers,
|
||||
shared.sd_model.sd_checkpoint_info,
|
||||
extra_network_data,
|
||||
opts.sdxl_crop_left,
|
||||
opts.sdxl_crop_top,
|
||||
self.width,
|
||||
self.height,
|
||||
)
|
||||
|
||||
for cache in caches:
|
||||
if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
|
||||
if cache[0] is not None and cached_params == cache[0]:
|
||||
return cache[1]
|
||||
|
||||
cache = caches[0]
|
||||
@ -339,14 +352,17 @@ class StableDiffusionProcessing:
|
||||
with devices.autocast():
|
||||
cache[1] = function(shared.sd_model, required_prompts, steps)
|
||||
|
||||
cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
|
||||
cache[0] = cached_params
|
||||
return cache[1]
|
||||
|
||||
def setup_conds(self):
|
||||
prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
|
||||
negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
|
||||
|
||||
sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
|
||||
self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
||||
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
|
||||
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
|
||||
|
||||
def parse_extra_network_prompts(self):
|
||||
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
|
||||
@ -523,8 +539,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
|
||||
|
||||
|
||||
def decode_first_stage(model, x):
|
||||
with devices.autocast(disable=x.dtype == devices.dtype_vae):
|
||||
x = model.decode_first_stage(x)
|
||||
x = model.decode_first_stage(x.to(devices.dtype_vae))
|
||||
|
||||
return x
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import List
|
||||
@ -109,7 +111,25 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
||||
ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts, steps):
|
||||
class SdConditioning(list):
|
||||
"""
|
||||
A list with prompts for stable diffusion's conditioner model.
|
||||
Can also specify width and height of created image - SDXL needs it.
|
||||
"""
|
||||
def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
|
||||
super().__init__()
|
||||
self.extend(prompts)
|
||||
|
||||
if copy_from is None:
|
||||
copy_from = prompts
|
||||
|
||||
self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
|
||||
self.width = width or getattr(copy_from, 'width', None)
|
||||
self.height = height or getattr(copy_from, 'height', None)
|
||||
|
||||
|
||||
|
||||
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
|
||||
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
|
||||
and the sampling step at which this condition is to be replaced by the next one.
|
||||
|
||||
@ -139,12 +159,17 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
res.append(cached)
|
||||
continue
|
||||
|
||||
texts = [x[1] for x in prompt_schedule]
|
||||
texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
|
||||
conds = model.get_learned_conditioning(texts)
|
||||
|
||||
cond_schedule = []
|
||||
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||
if isinstance(conds, dict):
|
||||
cond = {k: v[i] for k, v in conds.items()}
|
||||
else:
|
||||
cond = conds[i]
|
||||
|
||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))
|
||||
|
||||
cache[prompt] = cond_schedule
|
||||
res.append(cond_schedule)
|
||||
@ -155,11 +180,13 @@ def get_learned_conditioning(model, prompts, steps):
|
||||
re_AND = re.compile(r"\bAND\b")
|
||||
re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
|
||||
|
||||
def get_multicond_prompt_list(prompts):
|
||||
|
||||
def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
|
||||
res_indexes = []
|
||||
|
||||
prompt_flat_list = []
|
||||
prompt_indexes = {}
|
||||
prompt_flat_list = SdConditioning(prompts)
|
||||
prompt_flat_list.clear()
|
||||
|
||||
for prompt in prompts:
|
||||
subprompts = re_AND.split(prompt)
|
||||
@ -196,6 +223,7 @@ class MulticondLearnedConditioning:
|
||||
self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
|
||||
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
|
||||
|
||||
|
||||
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
|
||||
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
|
||||
For each prompt, the list is obtained by splitting the prompt using the AND separator.
|
||||
@ -214,20 +242,57 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
|
||||
return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
|
||||
|
||||
|
||||
class DictWithShape(dict):
|
||||
def __init__(self, x, shape):
|
||||
super().__init__()
|
||||
self.update(x)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self["crossattn"].shape
|
||||
|
||||
|
||||
def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
|
||||
param = c[0][0].cond
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
is_dict = isinstance(param, dict)
|
||||
|
||||
if is_dict:
|
||||
dict_cond = param
|
||||
res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
|
||||
res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
|
||||
else:
|
||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||
|
||||
for i, cond_schedule in enumerate(c):
|
||||
target_index = 0
|
||||
for current, entry in enumerate(cond_schedule):
|
||||
if current_step <= entry.end_at_step:
|
||||
target_index = current
|
||||
break
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
if is_dict:
|
||||
for k, param in cond_schedule[target_index].cond.items():
|
||||
res[k][i] = param
|
||||
else:
|
||||
res[i] = cond_schedule[target_index].cond
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def stack_conds(tensors):
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
|
||||
return torch.stack(tensors)
|
||||
|
||||
|
||||
|
||||
def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
param = c.batch[0][0].schedules[0].cond
|
||||
|
||||
@ -249,16 +314,14 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
||||
|
||||
conds_list.append(conds_for_batch)
|
||||
|
||||
# if prompts have wildly different lengths above the limit we'll get tensors fo different shapes
|
||||
# and won't be able to torch.stack them. So this fixes that.
|
||||
token_count = max([x.shape[0] for x in tensors])
|
||||
for i in range(len(tensors)):
|
||||
if tensors[i].shape[0] != token_count:
|
||||
last_vector = tensors[i][-1:]
|
||||
last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
|
||||
tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
|
||||
if isinstance(tensors[0], dict):
|
||||
keys = list(tensors[0].keys())
|
||||
stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
|
||||
stacked = DictWithShape(stacked, stacked['crossattn'].shape)
|
||||
else:
|
||||
stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
|
||||
|
||||
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
|
||||
return conds_list, stacked
|
||||
|
||||
|
||||
re_attention = re.compile(r"""
|
||||
|
@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
|
||||
import ldm.models.diffusion.plms
|
||||
import ldm.modules.encoders.modules
|
||||
|
||||
import sgm.modules.attention
|
||||
import sgm.modules.diffusionmodules.model
|
||||
import sgm.modules.diffusionmodules.openaimodel
|
||||
import sgm.modules.encoders.modules
|
||||
|
||||
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
|
||||
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
@ -56,6 +61,9 @@ def apply_optimizations(option=None):
|
||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
sgm.modules.diffusionmodules.model.nonlinearity = silu
|
||||
sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||
|
||||
if current_optimizer is not None:
|
||||
current_optimizer.undo()
|
||||
current_optimizer = None
|
||||
@ -89,6 +97,10 @@ def undo_optimizations():
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
|
||||
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
def fix_checkpoint():
|
||||
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
|
||||
@ -168,6 +180,32 @@ class StableDiffusionModelHijack:
|
||||
undo_optimizations()
|
||||
|
||||
def hijack(self, m):
|
||||
conditioner = getattr(m, 'conditioner', None)
|
||||
if conditioner:
|
||||
text_cond_models = []
|
||||
|
||||
for i in range(len(conditioner.embedders)):
|
||||
embedder = conditioner.embedders[i]
|
||||
typename = type(embedder).__name__
|
||||
if typename == 'FrozenOpenCLIPEmbedder':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenCLIPEmbedder':
|
||||
model_embeddings = embedder.transformer.text_model.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
if typename == 'FrozenOpenCLIPEmbedder2':
|
||||
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
|
||||
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
|
||||
text_cond_models.append(conditioner.embedders[i])
|
||||
|
||||
if len(text_cond_models) == 1:
|
||||
m.cond_stage_model = text_cond_models[0]
|
||||
else:
|
||||
m.cond_stage_model = conditioner
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
|
@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
|
||||
self.chunk_length = 75
|
||||
|
||||
self.is_trainable = getattr(wrapped, 'is_trainable', False)
|
||||
self.input_key = getattr(wrapped, 'input_key', 'txt')
|
||||
self.legacy_ucg_val = None
|
||||
|
||||
def empty_chunk(self):
|
||||
"""creates an empty PromptChunk and returns it"""
|
||||
|
||||
@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
"""
|
||||
Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
|
||||
Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
|
||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
|
||||
be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
|
||||
An example shape returned by this function can be: (2, 77, 768).
|
||||
For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
|
||||
Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
|
||||
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
|
||||
"""
|
||||
@ -242,7 +247,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
if hashes:
|
||||
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
|
||||
|
||||
return torch.hstack(zs)
|
||||
if getattr(self.wrapped, 'return_pooled', False):
|
||||
return torch.hstack(zs), zs[0].pooled
|
||||
else:
|
||||
return torch.hstack(zs)
|
||||
|
||||
def process_tokens(self, remade_batch_tokens, batch_multipliers):
|
||||
"""
|
||||
@ -265,9 +273,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
z *= (original_mean / new_mean)
|
||||
|
||||
return z
|
||||
|
||||
@ -324,3 +332,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
||||
|
||||
class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
|
||||
|
||||
if self.wrapped.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
else:
|
||||
z = outputs.hidden_states[self.wrapped.layer_idx]
|
||||
|
||||
return z
|
||||
|
@ -32,6 +32,40 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
|
||||
def __init__(self, wrapped, hijack):
|
||||
super().__init__(wrapped, hijack)
|
||||
|
||||
self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
|
||||
self.id_start = tokenizer.encoder["<start_of_text>"]
|
||||
self.id_end = tokenizer.encoder["<end_of_text>"]
|
||||
self.id_pad = 0
|
||||
|
||||
def tokenize(self, texts):
|
||||
assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
|
||||
|
||||
tokenized = [tokenizer.encode(text) for text in texts]
|
||||
|
||||
return tokenized
|
||||
|
||||
def encode_with_transformers(self, tokens):
|
||||
d = self.wrapped.encode_with_transformer(tokens)
|
||||
z = d[self.wrapped.layer]
|
||||
|
||||
pooled = d.get("pooled")
|
||||
if pooled is not None:
|
||||
z.pooled = pooled
|
||||
|
||||
return z
|
||||
|
||||
def encode_embedding_init_text(self, init_text, nvpt):
|
||||
ids = tokenizer.encode(init_text)
|
||||
ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
|
||||
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
|
||||
|
||||
return embedded
|
||||
|
@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
|
||||
import sgm.modules.attention
|
||||
import sgm.modules.diffusionmodules.model
|
||||
|
||||
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
|
||||
|
||||
|
||||
class SdOptimization:
|
||||
@ -39,6 +43,9 @@ class SdOptimization:
|
||||
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
|
||||
|
||||
|
||||
class SdOptimizationXformers(SdOptimization):
|
||||
name = "xformers"
|
||||
@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdpNoMem(SdOptimization):
|
||||
@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationSubQuad(SdOptimization):
|
||||
@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
|
||||
|
||||
|
||||
class SdOptimizationV1(SdOptimization):
|
||||
@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
|
||||
cmd_opt = "opt_split_attention_v1"
|
||||
priority = 10
|
||||
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|
||||
|
||||
|
||||
class SdOptimizationInvokeAI(SdOptimization):
|
||||
@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
|
||||
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
|
||||
|
||||
|
||||
class SdOptimizationDoggettx(SdOptimization):
|
||||
@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
|
||||
def apply(self):
|
||||
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
|
||||
sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
|
||||
|
||||
|
||||
def list_optimizers(res):
|
||||
@ -155,7 +173,7 @@ def get_available_vram():
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
|
||||
|
||||
# taken from https://github.com/Doggettx/stable-diffusion and modified
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q_in = self.to_q(x)
|
||||
@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
|
||||
def einsum_op_compvis(q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
|
||||
|
||||
def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
|
||||
r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
return r
|
||||
|
||||
|
||||
def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
|
||||
r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
|
||||
return r
|
||||
|
||||
|
||||
def einsum_op_mps_v1(q, k, v):
|
||||
if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
|
||||
return einsum_op_compvis(q, k, v)
|
||||
@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
|
||||
slice_size -= 1
|
||||
return einsum_op_slice_1(q, k, v, slice_size)
|
||||
|
||||
|
||||
def einsum_op_mps_v2(q, k, v):
|
||||
if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
|
||||
return einsum_op_compvis(q, k, v)
|
||||
else:
|
||||
return einsum_op_slice_0(q, k, v, 1)
|
||||
|
||||
|
||||
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
|
||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
|
||||
def einsum_op_cuda(q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
|
||||
def einsum_op(q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return einsum_op_cuda(q, k, v)
|
||||
@ -328,7 +354,8 @@ def einsum_op(q, k, v):
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||
|
||||
h = self.heads
|
||||
@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||
bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = q.shape
|
||||
@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
|
||||
return None
|
||||
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
context = default(context, x)
|
||||
@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
|
||||
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
batch_size, sequence_length, inner_dim = x.shape
|
||||
|
||||
if mask is not None:
|
||||
@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return scaled_dot_product_attention_forward(self, x, context, mask)
|
||||
|
||||
|
||||
def cross_attention_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x):
|
||||
|
||||
return h3
|
||||
|
||||
|
||||
def xformers_attnblock_forward(self, x):
|
||||
try:
|
||||
h_ = x
|
||||
@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x):
|
||||
except NotImplementedError:
|
||||
return cross_attention_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sdp_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x):
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
||||
|
||||
def sdp_no_mem_attnblock_forward(self, x):
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
|
||||
return sdp_attnblock_forward(self, x)
|
||||
|
||||
|
||||
def sub_quad_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
|
@ -14,7 +14,7 @@ import ldm.modules.midas as midas
|
||||
|
||||
from ldm.util import instantiate_from_config
|
||||
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
|
||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
|
||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||
from modules.timer import Timer
|
||||
import tomesd
|
||||
@ -289,6 +289,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
if state_dict is None:
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
model.is_sdxl = hasattr(model, 'conditioner')
|
||||
if model.is_sdxl:
|
||||
sd_models_xl.extend_sdxl(model)
|
||||
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
del state_dict
|
||||
timer.record("apply weights to model")
|
||||
@ -334,7 +338,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
|
||||
model.sd_checkpoint_info = checkpoint_info
|
||||
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
|
||||
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
if hasattr(model, 'logvar'):
|
||||
model.logvar = model.logvar.to(devices.device) # fix for training
|
||||
|
||||
sd_vae.delete_base_vae()
|
||||
sd_vae.clear_loaded_vae()
|
||||
@ -391,10 +396,11 @@ def repair_config(sd_config):
|
||||
if not hasattr(sd_config.model.params, "use_ema"):
|
||||
sd_config.model.params.use_ema = False
|
||||
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
if hasattr(sd_config.model.params, 'unet_config'):
|
||||
if shared.cmd_opts.no_half:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = False
|
||||
elif shared.cmd_opts.upcast_sampling:
|
||||
sd_config.model.params.unet_config.params.use_fp16 = True
|
||||
|
||||
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
|
||||
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
|
||||
@ -407,6 +413,8 @@ def repair_config(sd_config):
|
||||
|
||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
|
||||
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
|
||||
|
||||
|
||||
class SdModelData:
|
||||
@ -441,6 +449,15 @@ class SdModelData:
|
||||
model_data = SdModelData()
|
||||
|
||||
|
||||
def get_empty_cond(sd_model):
|
||||
if hasattr(sd_model, 'conditioner'):
|
||||
d = sd_model.get_learned_conditioning([""])
|
||||
return d['crossattn']
|
||||
else:
|
||||
return sd_model.cond_stage_model([""])
|
||||
|
||||
|
||||
|
||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
from modules import lowvram, sd_hijack
|
||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||
@ -461,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
|
||||
|
||||
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
|
||||
clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
|
||||
clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
|
||||
|
||||
timer.record("find config")
|
||||
|
||||
@ -513,7 +530,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||
timer.record("scripts callbacks")
|
||||
|
||||
with devices.autocast(), torch.no_grad():
|
||||
sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
|
||||
sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
|
||||
|
||||
timer.record("calculate empty prompt")
|
||||
|
||||
|
@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization
|
||||
|
||||
sd_configs_path = shared.sd_configs_path
|
||||
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
|
||||
sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
|
||||
|
||||
|
||||
config_default = shared.sd_default_config
|
||||
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
|
||||
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
|
||||
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
|
||||
config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
|
||||
config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
|
||||
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
|
||||
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
|
||||
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
|
||||
@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename):
|
||||
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
|
||||
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
|
||||
|
||||
if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl
|
||||
if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
|
||||
return config_sdxl_refiner
|
||||
elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
|
||||
return config_depth_model
|
||||
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
|
||||
return config_unclip
|
||||
|
99
modules/sd_models_xl.py
Normal file
99
modules/sd_models_xl.py
Normal file
@ -0,0 +1,99 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
import sgm.models.diffusion
|
||||
import sgm.modules.diffusionmodules.denoiser_scaling
|
||||
import sgm.modules.diffusionmodules.discretizer
|
||||
from modules import devices, shared, prompt_parser
|
||||
|
||||
|
||||
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
|
||||
for embedder in self.conditioner.embedders:
|
||||
embedder.ucg_rate = 0.0
|
||||
|
||||
width = getattr(self, 'target_width', 1024)
|
||||
height = getattr(self, 'target_height', 1024)
|
||||
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
|
||||
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
|
||||
|
||||
devices_args = dict(device=devices.device, dtype=devices.dtype)
|
||||
|
||||
sdxl_conds = {
|
||||
"txt": batch,
|
||||
"original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
"crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
|
||||
"target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
|
||||
"aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
|
||||
}
|
||||
|
||||
force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
|
||||
c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
|
||||
|
||||
return c
|
||||
|
||||
|
||||
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
|
||||
return self.model(x, t, cond)
|
||||
|
||||
|
||||
def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
|
||||
return x
|
||||
|
||||
|
||||
sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
|
||||
sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
|
||||
sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
|
||||
|
||||
|
||||
def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
|
||||
res = []
|
||||
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
|
||||
encoded = embedder.encode_embedding_init_text(init_text, nvpt)
|
||||
res.append(encoded)
|
||||
|
||||
return torch.cat(res, dim=1)
|
||||
|
||||
|
||||
def process_texts(self, texts):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
|
||||
return embedder.process_texts(texts)
|
||||
|
||||
|
||||
def get_target_prompt_token_count(self, token_count):
|
||||
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
|
||||
return embedder.get_target_prompt_token_count(token_count)
|
||||
|
||||
|
||||
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
|
||||
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
|
||||
sgm.modules.GeneralConditioner.process_texts = process_texts
|
||||
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
|
||||
|
||||
|
||||
def extend_sdxl(model):
|
||||
"""this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
|
||||
|
||||
dtype = next(model.model.diffusion_model.parameters()).dtype
|
||||
model.model.diffusion_model.dtype = dtype
|
||||
model.model.conditioning_key = 'crossattn'
|
||||
model.cond_stage_key = 'txt'
|
||||
# model.cond_stage_model will be set in sd_hijack
|
||||
|
||||
model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
|
||||
|
||||
discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
|
||||
model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
|
||||
|
||||
model.conditioner.wrapped = torch.nn.Module()
|
||||
|
||||
|
||||
sgm.modules.attention.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.model.print = lambda *args: None
|
||||
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
|
||||
sgm.modules.encoders.modules.print = lambda *args: None
|
||||
|
||||
# this gets the code to load the vanilla attention that we override
|
||||
sgm.modules.attention.SDP_IS_AVAILABLE = True
|
||||
sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
|
@ -28,6 +28,9 @@ def create_sampler(name, model):
|
||||
|
||||
assert config is not None, f'bad sampler name: {name}'
|
||||
|
||||
if model.is_sdxl and config.options.get("no_sdxl", False):
|
||||
raise Exception(f"Sampler {config.name} is not supported for SDXL")
|
||||
|
||||
sampler = config.constructor(model)
|
||||
sampler.config = config
|
||||
|
||||
|
@ -11,9 +11,9 @@ import modules.models.diffusion.uni_pc
|
||||
|
||||
|
||||
samplers_data_compvis = [
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}),
|
||||
sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
|
||||
sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}),
|
||||
]
|
||||
|
||||
|
||||
|
@ -53,6 +53,28 @@ k_diffusion_scheduler = {
|
||||
}
|
||||
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
@ -105,10 +127,13 @@ class CFGDenoiser(torch.nn.Module):
|
||||
|
||||
if shared.sd_model.model.conditioning_key == "crossattn-adm":
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
@ -140,28 +165,28 @@ class CFGDenoiser(torch.nn.Module):
|
||||
num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]
|
||||
|
||||
if num_repeats < 0:
|
||||
tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
|
||||
tensor = pad_cond(tensor, -num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
elif num_repeats > 0:
|
||||
uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
|
||||
uncond = pad_cond(uncond, num_repeats, empty)
|
||||
self.padded_cond_uncond = True
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = torch.cat([tensor, uncond, uncond])
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = torch.cat([tensor, uncond])
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
if shared.batch_cond_uncond:
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
|
||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
|
||||
@ -170,14 +195,14 @@ class CFGDenoiser(torch.nn.Module):
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = [tensor[a:b]]
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
|
@ -2,9 +2,9 @@ import os
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from modules import devices, paths
|
||||
from modules import devices, paths, shared
|
||||
|
||||
sd_vae_approx_model = None
|
||||
sd_vae_approx_models = {}
|
||||
|
||||
|
||||
class VAEApprox(nn.Module):
|
||||
@ -31,30 +31,55 @@ class VAEApprox(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
print(f'Downloading VAEApprox model to: {model_path}')
|
||||
torch.hub.download_url_to_file(model_url, model_path)
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_approx_model
|
||||
model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
|
||||
loaded_model = sd_vae_approx_models.get(model_name)
|
||||
|
||||
if sd_vae_approx_model is None:
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
|
||||
sd_vae_approx_model = VAEApprox()
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
if not os.path.exists(model_path):
|
||||
model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
|
||||
sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
sd_vae_approx_model.eval()
|
||||
sd_vae_approx_model.to(devices.device, devices.dtype)
|
||||
model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
|
||||
|
||||
return sd_vae_approx_model
|
||||
if not os.path.exists(model_path):
|
||||
model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
|
||||
download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
|
||||
|
||||
loaded_model = VAEApprox()
|
||||
loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_approx_models[model_name] = loaded_model
|
||||
|
||||
return loaded_model
|
||||
|
||||
|
||||
def cheap_approximation(sample):
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
|
||||
|
||||
coefs = torch.tensor([
|
||||
[0.298, 0.207, 0.208],
|
||||
[0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]).to(sample.device)
|
||||
if shared.sd_model.is_sdxl:
|
||||
coeffs = [
|
||||
[ 0.3448, 0.4168, 0.4395],
|
||||
[-0.1953, -0.0290, 0.0250],
|
||||
[ 0.1074, 0.0886, -0.0163],
|
||||
[-0.3730, -0.2499, -0.2088],
|
||||
]
|
||||
else:
|
||||
coeffs = [
|
||||
[ 0.298, 0.207, 0.208],
|
||||
[ 0.187, 0.286, 0.173],
|
||||
[-0.158, 0.189, 0.264],
|
||||
[-0.184, -0.271, -0.473],
|
||||
]
|
||||
|
||||
coefs = torch.tensor(coeffs).to(sample.device)
|
||||
|
||||
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
|
||||
|
||||
|
@ -8,9 +8,9 @@ import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modules import devices, paths_internal
|
||||
from modules import devices, paths_internal, shared
|
||||
|
||||
sd_vae_taesd = None
|
||||
sd_vae_taesd_models = {}
|
||||
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
@ -61,9 +61,7 @@ class TAESD(nn.Module):
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
|
||||
def download_model(model_path):
|
||||
model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
|
||||
|
||||
def download_model(model_path, model_url):
|
||||
if not os.path.exists(model_path):
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
@ -72,17 +70,19 @@ def download_model(model_path):
|
||||
|
||||
|
||||
def model():
|
||||
global sd_vae_taesd
|
||||
model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
|
||||
loaded_model = sd_vae_taesd_models.get(model_name)
|
||||
|
||||
if sd_vae_taesd is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
|
||||
download_model(model_path)
|
||||
if loaded_model is None:
|
||||
model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
|
||||
download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
|
||||
|
||||
if os.path.exists(model_path):
|
||||
sd_vae_taesd = TAESD(model_path)
|
||||
sd_vae_taesd.eval()
|
||||
sd_vae_taesd.to(devices.device, devices.dtype)
|
||||
loaded_model = TAESD(model_path)
|
||||
loaded_model.eval()
|
||||
loaded_model.to(devices.device, devices.dtype)
|
||||
sd_vae_taesd_models[model_name] = loaded_model
|
||||
else:
|
||||
raise FileNotFoundError('TAESD model not found')
|
||||
|
||||
return sd_vae_taesd.decoder
|
||||
return loaded_model.decoder
|
||||
|
@ -429,9 +429,16 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
||||
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
|
||||
"sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
|
||||
"sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
|
||||
"sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
|
||||
"sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('optimizations', "Optimizations"), {
|
||||
"cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
|
||||
"token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
|
||||
"token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
"token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
|
||||
|
@ -14,6 +14,7 @@ kornia
|
||||
lark
|
||||
numpy
|
||||
omegaconf
|
||||
open-clip-torch
|
||||
|
||||
piexif
|
||||
psutil
|
||||
|
@ -15,6 +15,7 @@ kornia==0.6.7
|
||||
lark==1.1.2
|
||||
numpy==1.23.5
|
||||
omegaconf==2.2.3
|
||||
open-clip-torch==2.20.0
|
||||
piexif==1.1.3
|
||||
psutil~=5.9.5
|
||||
pytorch_lightning==1.9.4
|
||||
|
Loading…
Reference in New Issue
Block a user