SDXL support
This commit is contained in:
parent af081211ee · commit da464a3fb3
@@ -224,6 +224,20 @@ def run_extensions_installers(settings_file):
         run_extension_installer(os.path.join(extensions_dir, dirname_extension))
 
 
+def mute_sdxl_imports():
+    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+    import importlib
+
+    module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None))
+    module.LPIPS = None
+    sys.modules['taming.modules.losses.lpips'] = module
+
+    module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None))
+    module.StableDataModuleFromConfig = None
+    sys.modules['sgm.data'] = module
+
+
 def prepare_environment():
     torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
     torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
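Note: `mute_sdxl_imports()` registers synthetic modules so imports the SDXL (sgm) code pulls in but never uses resolve without the real dependency being installed. A minimal, self-contained sketch of the same technique; `install_stub` is a hypothetical helper name, not part of the commit:

```python
# Sketch of the stub-module trick: register a synthetic module in
# sys.modules so a later import resolves without the real package.
import importlib.machinery
import importlib.util
import sys


def install_stub(name, **attrs):
    # A ModuleSpec with loader=None is enough for module_from_spec; the
    # resulting module starts empty and we attach placeholder attributes.
    module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec(name, None))
    for key, value in attrs.items():
        setattr(module, key, value)
    sys.modules[name] = module
    return module


install_stub('taming.modules.losses.lpips', LPIPS=None)

# The import system checks sys.modules for the full dotted name first, so
# this succeeds even if no real `taming` package is installed:
from taming.modules.losses.lpips import LPIPS
assert LPIPS is None
```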
@@ -319,11 +333,14 @@ def prepare_environment():
     if args.update_all_extensions:
         git_pull_recursive(extensions_dir)
 
+
+    mute_sdxl_imports()
+
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)
 
 
 def configure_for_tests():
     if "--api" not in sys.argv:
         sys.argv.append("--api")
@@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram):
         send_me_to_gpu(first_stage_model, None)
         return first_stage_model_decode(z)
 
-    # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field
-    if hasattr(sd_model.cond_stage_model, 'model'):
-        sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model
+    to_remain_in_cpu = [
+        (sd_model, 'first_stage_model'),
+        (sd_model, 'depth_model'),
+        (sd_model, 'embedder'),
+        (sd_model, 'model'),
+        (sd_model, 'embedder'),
+    ]
 
-    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then
-    # send the model to GPU. Then put modules back. the modules will be in CPU.
-    stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+    is_sdxl = hasattr(sd_model, 'conditioner')
+    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+    if is_sdxl:
+        to_remain_in_cpu.append((sd_model, 'conditioner'))
+    elif is_sd2:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+    else:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
+    # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
+    stored = []
+    for obj, field in to_remain_in_cpu:
+        module = getattr(obj, field, None)
+        stored.append(module)
+        setattr(obj, field, None)
+
+    # send the model to GPU.
     sd_model.to(devices.device)
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+
+    # put modules back. the modules will be in CPU.
+    for (obj, field), module in zip(to_remain_in_cpu, stored):
+        setattr(obj, field, module)
 
     # register hooks for those the first three models
-    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+    if is_sdxl:
+        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+    elif is_sd2:
+        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+    else:
+        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.encode = first_stage_model_encode_wrap
     sd_model.first_stage_model.decode = first_stage_model_decode_wrap
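Note: this hunk (the `setup_for_low_vram` function, likely modules/lowvram.py) generalizes the old fixed 5-tuple of detached submodules into a list of (owner, attribute-name) pairs, so SDXL's `conditioner` can join the set without more special cases. A toy sketch of the detach/move/reattach pattern; `Part`/`Model` are stand-ins, not webui classes:

```python
# Toy version of the offload pattern: detach large submodules, move the
# (now small) parent, then reattach the submodules where they were.
class Part:
    def __init__(self, name):
        self.name = name

class Model:
    def __init__(self):
        self.first_stage_model = Part('vae')
        self.model = Part('unet')
        self.conditioner = Part('sdxl text encoders')  # only SDXL has this

sd_model = Model()

to_remain_in_cpu = [(sd_model, 'first_stage_model'), (sd_model, 'model')]
if hasattr(sd_model, 'conditioner'):  # same SDXL test as in the diff
    to_remain_in_cpu.append((sd_model, 'conditioner'))

# detach: remember each module, then null the attribute on its owner
stored = []
for obj, field in to_remain_in_cpu:
    stored.append(getattr(obj, field, None))
    setattr(obj, field, None)

# ... the stripped parent would be moved here with sd_model.to(device) ...

# reattach: the detached modules stay wherever they already were (CPU)
for (obj, field), module in zip(to_remain_in_cpu, stored):
    setattr(obj, field, module)

assert sd_model.model.name == 'unet'
```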
@@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram):
         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
-    if hasattr(sd_model.cond_stage_model, 'model'):
-        sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer
-        del sd_model.cond_stage_model.transformer
-
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
     else:
@@ -20,7 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl
 
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
-    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -36,6 +36,13 @@ for d, must_exist, what, options in path_dirs:
         d = os.path.abspath(d)
         if "atstart" in options:
             sys.path.insert(0, d)
+        elif "sgm" in options:
+            # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir.
+
+            sys.path.insert(0, d)
+            import sgm
+            sys.path.pop(0)
         else:
             sys.path.append(d)
         paths[what] = d
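Note: the "sgm" option (likely modules/paths.py) imports the SDXL repo's top-level package while its path is temporarily first on `sys.path`, then removes the path so the repo's `scripts/` package cannot shadow extensions' `scripts` imports. A generic sketch of the same sandwich; `import_without_lingering_path` and `repo_dir` are illustrative names:

```python
import sys

def import_without_lingering_path(repo_dir, package_name):
    """Import `package_name` from `repo_dir` without leaving repo_dir on sys.path."""
    sys.path.insert(0, repo_dir)
    try:
        module = __import__(package_name)  # stays cached in sys.modules afterwards
    finally:
        # assumes nothing else touched sys.path meanwhile; after the pop,
        # `import scripts.something` no longer sees repo_dir's scripts/ dir
        sys.path.pop(0)
    return module
```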
@@ -343,10 +343,13 @@ class StableDiffusionProcessing:
         return cache[1]
 
     def setup_conds(self):
+        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height)
+
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
-        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
-        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
 
     def parse_extra_network_prompts(self):
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import re
 from collections import namedtuple
 from typing import List
@@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])
 
 
-def get_learned_conditioning(model, prompts, steps):
+class SdConditioning(list):
+    """
+    A list with prompts for stable diffusion's conditioner model.
+    Can also specify width and height of created image - SDXL needs it.
+    """
+    def __init__(self, prompts, width=None, height=None):
+        super().__init__()
+        self.extend(prompts)
+        self.width = width or getattr(prompts, 'width', None)
+        self.height = height or getattr(prompts, 'height', None)
+
+
+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
     and the sampling step at which this condition is to be replaced by the next one.
 
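Note: `SdConditioning` is a plain list of prompts that also carries the intended image geometry, which the SDXL conditioner reads downstream. A self-contained sketch of its behaviour, copying the class body from the hunk above:

```python
class SdConditioning(list):
    """A list of prompts that also carries the intended image size (SDXL reads it)."""
    def __init__(self, prompts, width=None, height=None):
        super().__init__()
        self.extend(prompts)
        self.width = width or getattr(prompts, 'width', None)
        self.height = height or getattr(prompts, 'height', None)

prompts = SdConditioning(["a photo of a cat"], width=1024, height=1024)
copy = SdConditioning(prompts)  # geometry survives copy-construction via getattr
assert list(copy) == ["a photo of a cat"]
assert (copy.width, copy.height) == (1024, 1024)
```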
@@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps):
 re_AND = re.compile(r"\bAND\b")
 re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
 
-def get_multicond_prompt_list(prompts):
+
+def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
     res_indexes = []
 
-    prompt_flat_list = []
     prompt_indexes = {}
+    prompt_flat_list = SdConditioning(prompts)
+    prompt_flat_list.clear()
 
     for prompt in prompts:
         subprompts = re_AND.split(prompt)
@@ -201,6 +217,7 @@ class MulticondLearnedConditioning:
         self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS
         self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
 
+
 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:
     """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
     For each prompt, the list is obtained by splitting the prompt using the AND separator.
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
 import ldm.modules.encoders.modules
 
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
 
+    sgm.modules.diffusionmodules.model.nonlinearity = silu
+    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
     if current_optimizer is not None:
         current_optimizer.undo()
         current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
+    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
 
 def fix_checkpoint():
     """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -170,10 +182,19 @@ class StableDiffusionModelHijack:
         if conditioner:
             for i in range(len(conditioner.embedders)):
                 embedder = conditioner.embedders[i]
-                if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder':
+                typename = type(embedder).__name__
+                if typename == 'FrozenOpenCLIPEmbedder':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                     m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
                     conditioner.embedders[i] = m.cond_stage_model
+                if typename == 'FrozenCLIPEmbedder':
+                    model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+                    m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self)
+                    conditioner.embedders[i] = m.cond_stage_model
+                if typename == 'FrozenOpenCLIPEmbedder2':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
 
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
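Note: each SDXL embedder is matched on `type(...).__name__` rather than with `isinstance`, so the hijack code does not have to import the concrete embedder classes from the sgm repo. A minimal sketch of the dispatch shape; the stand-in class and return strings are illustrative only:

```python
# Dispatch on the class name so no import of the real embedder classes
# is needed; unknown embedders pass through untouched.
class FrozenOpenCLIPEmbedder2:
    pass

def wrapper_name_for(embedder):
    typename = type(embedder).__name__
    if typename == 'FrozenCLIPEmbedder':
        return 'FrozenCLIPEmbedderWithCustomWords'
    if typename == 'FrozenOpenCLIPEmbedder2':
        return 'FrozenOpenCLIPEmbedder2WithCustomWords'
    return None

assert wrapper_name_for(FrozenOpenCLIPEmbedder2()) == 'FrozenOpenCLIPEmbedder2WithCustomWords'
```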
@@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
         self.chunk_length = 75
 
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
     def empty_chunk(self):
         """creates an empty PromptChunk and returns it"""
 
@@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         """
         Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
         Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
-        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.
         An example shape returned by this function can be: (2, 77, 768).
+        For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
         Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """
@@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
             self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
 
-        return torch.hstack(zs)
+        if getattr(self.wrapped, 'return_pooled', False):
+            return torch.hstack(zs), zs[0].pooled
+        else:
+            return torch.hstack(zs)
 
     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         """
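Note: embedders that set `return_pooled` hand back a `(hidden_states, pooled)` tuple instead of a single tensor, matching the docstring shapes above. A runnable sketch of the contract; `FakeEmbedder` and `encode` are stand-ins, not webui code:

```python
import torch

class FakeEmbedder:
    return_pooled = True

def encode(wrapped, zs):
    # mirrors the diff: tuple return only when the wrapped embedder asks for it
    if getattr(wrapped, 'return_pooled', False):
        return torch.hstack(zs), zs[0].pooled
    return torch.hstack(zs)

z = torch.zeros(2, 77, 1280)
z.pooled = torch.zeros(2, 1280)  # pooled values ride along as a tensor attribute
out, pooled = encode(FakeEmbedder(), [z])
print(out.shape, pooled.shape)  # torch.Size([2, 77, 1280]) torch.Size([2, 1280])
```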
@@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
             batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
             original_mean = z.mean()
-            z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+            z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
             new_mean = z.mean()
-            z = z * (original_mean / new_mean)
+            z *= (original_mean / new_mean)
 
         return z
 
@@ -16,10 +16,6 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
         self.id_end = tokenizer.encoder["<end_of_text>"]
         self.id_pad = 0
 
-        self.is_trainable = getattr(wrapped, 'is_trainable', False)
-        self.input_key = getattr(wrapped, 'input_key', 'txt')
-        self.legacy_ucg_val = None
-
     def tokenize(self, texts):
         assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
 
@@ -39,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
         embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
 
         return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
+
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+        tokenized = [tokenizer.encode(text) for text in texts]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        d = self.wrapped.encode_with_transformer(tokens)
+        z = d[self.wrapped.layer]
+
+        pooled = d.get("pooled")
+        if pooled is not None:
+            z.pooled = pooled
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
+
+        return embedded
@@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
 
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward
 
 
 class SdOptimization:
@@ -39,6 +43,9 @@ class SdOptimization:
         ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
+        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
+
 
 class SdOptimizationXformers(SdOptimization):
     name = "xformers"
@@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization):
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
 
 
 class SdOptimizationSdpNoMem(SdOptimization):
@@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization):
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
 
 
 class SdOptimizationSdp(SdOptimizationSdpNoMem):
@@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem):
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
 
 
 class SdOptimizationSubQuad(SdOptimization):
@@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization):
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
 
 
 class SdOptimizationV1(SdOptimization):
@@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization):
     cmd_opt = "opt_split_attention_v1"
     priority = 10
 
-
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
 
 
 class SdOptimizationInvokeAI(SdOptimization):
@@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization):
 
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
 
 
 class SdOptimizationDoggettx(SdOptimization):
@@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization):
     def apply(self):
         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
 
 
 def list_optimizers(res):
@@ -155,7 +173,7 @@ def get_available_vram():
 
 
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
     h = self.heads
 
     q_in = self.to_q(x)
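Note: every custom attention forward in this file gains `additional_tokens` and `n_times_crossframe_attn_in_self` parameters. The reason is that these functions are monkey-patched over both ldm's and sgm's `CrossAttention.forward`, and sgm's callers pass those extra keywords; the replacements must accept them even though they ignore them. A generic sketch of the same compatibility idea; `make_signature_compatible` is a hypothetical adapter, not part of the commit:

```python
# Any replacement for sgm's CrossAttention.forward must accept the extra
# keyword arguments sgm passes, even if it does nothing with them.
def make_signature_compatible(forward):
    def wrapper(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
        # extra kwargs accepted for sgm compatibility, then dropped
        return forward(self, x, context=context, mask=mask)
    return wrapper
```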
@@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
 
 
 # taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None):
+def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
     h = self.heads
 
     q_in = self.to_q(x)
@@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
 mem_total_gb = psutil.virtual_memory().total // (1 << 30)
 
+
 def einsum_op_compvis(q, k, v):
     s = einsum('b i d, b j d -> b i j', q, k)
     s = s.softmax(dim=-1, dtype=s.dtype)
     return einsum('b i j, b j d -> b i d', s, v)
 
+
 def einsum_op_slice_0(q, k, v, slice_size):
     r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
     for i in range(0, q.shape[0], slice_size):
@@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size):
         r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
     return r
 
+
 def einsum_op_slice_1(q, k, v, slice_size):
     r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
     for i in range(0, q.shape[1], slice_size):
@@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size):
         r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
     return r
 
+
 def einsum_op_mps_v1(q, k, v):
     if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
         return einsum_op_compvis(q, k, v)
@@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v):
         slice_size -= 1
     return einsum_op_slice_1(q, k, v, slice_size)
 
+
 def einsum_op_mps_v2(q, k, v):
     if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
         return einsum_op_compvis(q, k, v)
     else:
         return einsum_op_slice_0(q, k, v, 1)
 
+
 def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
     size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
     if size_mb <= max_tensor_mb:
|
|||||||
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
return einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||||
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||||
|
|
||||||
|
|
||||||
def einsum_op_cuda(q, k, v):
|
def einsum_op_cuda(q, k, v):
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v):
|
|||||||
# Divide factor of safety as there's copying and fragmentation
|
# Divide factor of safety as there's copying and fragmentation
|
||||||
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||||
|
|
||||||
|
|
||||||
def einsum_op(q, k, v):
|
def einsum_op(q, k, v):
|
||||||
if q.device.type == 'cuda':
|
if q.device.type == 'cuda':
|
||||||
return einsum_op_cuda(q, k, v)
|
return einsum_op_cuda(q, k, v)
|
||||||
@ -328,7 +354,8 @@ def einsum_op(q, k, v):
|
|||||||
# Tested on i7 with 8MB L3 cache.
|
# Tested on i7 with 8MB L3 cache.
|
||||||
return einsum_op_tensor_mem(q, k, v, 32)
|
return einsum_op_tensor_mem(q, k, v, 32)
|
||||||
|
|
||||||
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|
||||||
|
def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||||
h = self.heads
|
h = self.heads
|
||||||
|
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||||
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||||
def sub_quad_attention_forward(self, x, context=None, mask=None):
|
def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
|
||||||
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||||
|
|
||||||
h = self.heads
|
h = self.heads
|
||||||
@@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
 
     return x
 
+
 def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
     bytes_per_token = torch.finfo(q.dtype).bits//8
     batch_x_heads, q_tokens, _ = q.shape
@@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
     return None
 
 
-def xformers_attention_forward(self, x, context=None, mask=None):
+def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
     h = self.heads
     q_in = self.to_q(x)
     context = default(context, x)
@@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
 
+
 # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
 # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
     batch_size, sequence_length, inner_dim = x.shape
 
     if mask is not None:
@@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
         hidden_states = self.to_out[1](hidden_states)
     return hidden_states
 
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
         return scaled_dot_product_attention_forward(self, x, context, mask)
 
+
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
@@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x):
 
     return h3
 
+
 def xformers_attnblock_forward(self, x):
     try:
         h_ = x
@@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x):
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
 
+
 def sdp_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
@@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x):
     out = self.proj_out(out)
     return x + out
 
+
 def sdp_no_mem_attnblock_forward(self, x):
     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
         return sdp_attnblock_forward(self, x)
 
+
 def sub_quad_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
@@ -411,6 +411,7 @@ def repair_config(sd_config):
 
 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
 
 
 class SdModelData:
@@ -445,6 +446,15 @@ class SdModelData:
 model_data = SdModelData()
 
 
+def get_empty_cond(sd_model):
+    if hasattr(sd_model, 'conditioner'):
+        d = sd_model.get_learned_conditioning([""])
+        return d['crossattn']
+    else:
+        return sd_model.cond_stage_model([""])
+
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
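Note: `get_empty_cond` branches because SDXL's conditioner returns a dict of conditioning tensors while SD1/SD2's `cond_stage_model` returns a tensor directly. A sketch with stand-in objects (the classes and string payloads are illustrative, not webui code):

```python
class SdxlModel:
    conditioner = object()  # its presence marks an SDXL model, as in the diff
    def get_learned_conditioning(self, prompts):
        return {'crossattn': f'crossattn for {prompts}', 'vector': '...'}

class Sd1Model:
    def cond_stage_model(self, prompts):
        return f'cond for {prompts}'

def get_empty_cond(sd_model):
    if hasattr(sd_model, 'conditioner'):
        return sd_model.get_learned_conditioning([""])['crossattn']
    else:
        return sd_model.cond_stage_model([""])

print(get_empty_cond(SdxlModel()))
print(get_empty_cond(Sd1Model()))
```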
@@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict
 
     timer.record("find config")
 
@@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer.record("scripts callbacks")
 
     with devices.autocast(), torch.no_grad():
-        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
 
     timer.record("calculate empty prompt")
 
@@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
 config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -70,7 +71,9 @@ def guess_model_config_from_state_dict(sd, filename):
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
 
-    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+        return config_sdxl
+    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
         return config_depth_model
     elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
         return config_unclip
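Note: the config guesser keys off tensors that exist in only one family of checkpoints; `conditioner.embedders.1.model.ln_final.weight` is the second (OpenCLIP) text encoder that only SDXL carries. A minimal sketch of the same detection idea, using the SDXL key from this hunk and the SD2 key from the sd_models.py hunk above; `guess_family` is an illustrative helper:

```python
def guess_family(state_dict):
    # keys taken from the diff: sdxl_clip_weight and sd2_clip_weight
    if 'conditioner.embedders.1.model.ln_final.weight' in state_dict:
        return 'sdxl'
    if 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' in state_dict:
        return 'sd2'
    return 'sd1'

print(guess_family({'conditioner.embedders.1.model.ln_final.weight': None}))  # sdxl
```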
@@ -1,18 +1,30 @@
 from __future__ import annotations
 
+import sys
+
 import torch
 
 import sgm.models.diffusion
 import sgm.modules.diffusionmodules.denoiser_scaling
 import sgm.modules.diffusionmodules.discretizer
-from modules import devices
+from modules import devices, shared, prompt_parser
 
 
-def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]):
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
     for embedder in self.conditioner.embedders:
         embedder.ucg_rate = 0.0
 
-    c = self.conditioner({'txt': batch})
+    width = getattr(self, 'target_width', 1024)
+    height = getattr(self, 'target_height', 1024)
+
+    sdxl_conds = {
+        "txt": batch,
+        "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+        "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
+    }
+
+    c = self.conditioner(sdxl_conds)
 
     return c
 
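Note: besides the prompt text, SDXL's conditioner consumes the original size, crop coordinates, and target size as per-prompt rows repeated across the batch. A simplified shape check of the dict built above (the `devices.*` moves are dropped; the 2-prompt batch and 1024x1024 defaults are illustrative):

```python
import torch

batch = ["a cat", "a dog"]
width = height = 1024
crop_top = crop_left = 0  # defaults of shared.opts.sdxl_crop_top / sdxl_crop_left

sdxl_conds = {
    "txt": batch,
    "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1),
    "crop_coords_top_left": torch.tensor([crop_top, crop_left]).repeat(len(batch), 1),
    "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1),
}
print(sdxl_conds["original_size_as_tuple"].shape)  # torch.Size([2, 2]): one row per prompt
```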
@@ -26,7 +38,7 @@ def extend_sdxl(model):
     model.model.diffusion_model.dtype = dtype
     model.model.conditioning_key = 'crossattn'
 
-    model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0]
+    model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0]
     model.cond_stage_key = model.cond_stage_model.input_key
 
     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
@@ -34,7 +46,14 @@ def extend_sdxl(model):
     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
 
+    model.is_xl = True
+
 
 sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
 sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+
+sgm.modules.attention.print = lambda *args: None
+sgm.modules.diffusionmodules.model.print = lambda *args: None
+sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
+sgm.modules.encoders.modules.print = lambda *args: None
@@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module):
                 for batch_offset in range(0, x_out.shape[0], batch_size):
                     a = batch_offset
                     b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b]))
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
             else:
                 x_out = torch.zeros_like(x_in)
                 batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -428,6 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
     "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
     "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
+    "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"),
+    "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"),
 }))
 
 options_templates.update(options_section(('optimizations', "Optimizations"), {
@@ -14,6 +14,7 @@ kornia
 lark
 numpy
 omegaconf
+open-clip-torch
 piexif
 psutil
@@ -15,6 +15,7 @@ kornia==0.6.7
 lark==1.1.2
 numpy==1.23.5
 omegaconf==2.2.3
+open-clip-torch==2.20.0
 piexif==1.1.3
 psutil~=5.9.5
 pytorch_lightning==1.9.4