Merge pull request #6510 from brkirch/unet16-upcast-precision

Add upcast options, full precision sampling from float16 UNet and upcasting attention for inference using SD 2.1 models without --no-half

commit 1574e96729
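The core idea of --upcast-sampling, as wired up in the diff below: the UNet weights stay in float16, but sampling runs in float32, with tensors cast down to the UNet's dtype at its boundary and its outputs cast back up, while autocast is kept off so the surrounding math stays at full precision. The following is a minimal, hedged sketch of that casting pattern only; TinyUnet and apply_model here are illustrative stand-ins, not code from this PR.

import torch

dtype_unet = torch.float16                     # UNet weights stay in half precision

class TinyUnet(torch.nn.Module):               # hypothetical stand-in for the real UNet
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(8))

    def forward(self, x):
        return x * self.scale

unet = TinyUnet().to(dtype_unet)

def apply_model(x_noisy):
    # cast inputs down to the UNet dtype, run the model, hand float32 back to the sampler;
    # this mirrors what the CondFunc patch on LatentDiffusion.apply_model does later in the diff
    return unet(x_noisy.to(dtype_unet)).float()

x = torch.randn(2, 8)                          # the sampler itself keeps working in float32
print(apply_model(x).dtype)                    # torch.float32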
@@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
 - Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
 - (You)
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from modules import devices
+
 # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
 
 
@@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module):
         t_358, = inputs
         t_359 = t_358.permute(*[0, 3, 1, 2])
         t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
-        t_360 = self.n_Conv_0(t_359_padded)
+        t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
         t_361 = F.relu(t_360)
         t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
         t_362 = self.n_MaxPool_0(t_361)
@@ -79,6 +79,8 @@ cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
 
 
 def randn(seed, shape):
@@ -106,6 +108,10 @@ def autocast(disable=False):
     return torch.autocast("cuda")
 
 
+def without_autocast(disable=False):
+    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
 class NansException(Exception):
     pass
 
@@ -123,7 +129,7 @@ def test_for_nans(x, where):
         message = "A tensor with all NaNs was produced in Unet."
 
         if not shared.cmd_opts.no_half:
-            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
 
     elif where == "vae":
         message = "A tensor with all NaNs was produced in VAE."
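A note on the two context managers added above: autocast() wraps sampling in torch.autocast("cuda") so float32 inputs are computed in float16 where possible, while the new without_autocast() temporarily disables an active autocast region so tensors that were deliberately upcast actually stay in float32. A small illustrative check of that behaviour, assuming a CUDA device (this snippet is not part of the PR):

import contextlib
import torch

def without_autocast(disable=False):
    # same shape as the helper in the devices module above: turn autocast off only if it is on
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()

if torch.cuda.is_available():
    a = torch.randn(4, 4, device="cuda")
    b = torch.randn(4, 4, device="cuda")
    with torch.autocast("cuda"):
        print((a @ b).dtype)          # torch.float16: autocast downcasts the matmul
        with without_autocast():
            print((a @ b).dtype)      # torch.float32: the upcast path keeps full precision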
@@ -172,7 +172,8 @@ class StableDiffusionProcessing:
         midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
         midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
 
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+        conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
         conditioning = torch.nn.functional.interpolate(
             self.sd_model.depth_model(midas_in),
             size=conditioning_image.shape[2:],
@@ -203,7 +204,7 @@ class StableDiffusionProcessing:
 
         # Create another latent image, this time with a masked version of the original input.
         # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
-        conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
+        conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype)
         conditioning_image = torch.lerp(
             source_image,
             source_image * (1.0 - conditioning_mask),
@@ -211,7 +212,7 @@ class StableDiffusionProcessing:
         )
 
         # Encode the new masked image using first stage of network.
-        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
 
         # Create the concatenated conditioning tensor to be fed to `c_concat`
         conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -225,10 +226,10 @@ class StableDiffusionProcessing:
         # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
         # identify itself with a field common to all models. The conditioning_key is also hybrid.
         if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
-            return self.depth2img_image_conditioning(source_image)
+            return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
 
         if self.sampler.conditioning_key in {'hybrid', 'concat'}:
-            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+            return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
 
         # Dummy zero conditioning if we're not using inpainting or depth model.
         return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
@@ -614,7 +615,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.autocast():
+            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
             x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
@@ -992,7 +993,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
-        image = image.to(shared.device)
+        image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
 
@@ -9,7 +9,7 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, errors
+from modules import shared, errors, devices
 from modules.hypernetworks import hypernetwork
 
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
-    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-    for i in range(0, q.shape[0], 2):
-        end = i + 2
-        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
-        s1 *= self.scale
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
 
-        s2 = s1.softmax(dim=-1)
-        del s1
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], 2):
+            end = i + 2
+            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+            s1 *= self.scale
+
+            s2 = s1.softmax(dim=-1)
+            del s1
+
+            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+            del s2
+        del q, k, v
 
-        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
-        del s2
-    del q, k, v
+    r1 = r1.to(dtype)
 
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
@@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
-    k_in *= self.scale
-
-    del context, x
-
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
-    del q_in, k_in, v_in
-
-    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
-
-    mem_free_total = get_available_vram()
-
-    gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
-    modifier = 3 if q.element_size() == 2 else 2.5
-    mem_required = tensor_size * modifier
-    steps = 1
-
-    if mem_required > mem_free_total:
-        steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
-        # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
-        #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
-
-    if steps > 64:
-        max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
-        raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
-                           f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
-
-    slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
-    for i in range(0, q.shape[1], slice_size):
-        end = i + slice_size
-        s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
-
-        s2 = s1.softmax(dim=-1, dtype=q.dtype)
-        del s1
-
-        r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-        del s2
-
-    del q, k, v
+    dtype = q_in.dtype
+    if shared.opts.upcast_attn:
+        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k_in = k_in * self.scale
+
+        del context, x
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+        del q_in, k_in, v_in
+
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+
+        mem_free_total = get_available_vram()
+
+        gb = 1024 ** 3
+        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+        modifier = 3 if q.element_size() == 2 else 2.5
+        mem_required = tensor_size * modifier
+        steps = 1
+
+        if mem_required > mem_free_total:
+            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
+            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
+            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
+
+        if steps > 64:
+            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
+            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
+                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
+
+        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        for i in range(0, q.shape[1], slice_size):
+            end = i + slice_size
+            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
+
+            s2 = s1.softmax(dim=-1, dtype=q.dtype)
+            del s1
+
+            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
+            del s2
+
+        del q, k, v
+
+    r1 = r1.to(dtype)
 
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
     context = default(context, x)
 
     context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
-    k = self.to_k(context_k) * self.scale
+    k = self.to_k(context_k)
     v = self.to_v(context_v)
     del context, context_k, context_v, x
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-    r = einsum_op(q, k, v)
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k = k * self.scale
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = einsum_op(q, k, v)
+    r = r.to(dtype)
     return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
 
 # -- End of code from https://github.com/invoke-ai/InvokeAI --
@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
     v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
 
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
     x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
 
+    x = x.to(dtype)
+
     x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
 
     out_proj, dropout = self.to_out
@@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
         query_chunk_size = q_tokens
         kv_chunk_size = k_tokens
 
-    return efficient_dot_product_attention(
-        q,
-        k,
-        v,
-        query_chunk_size=q_chunk_size,
-        kv_chunk_size=kv_chunk_size,
-        kv_chunk_size_min = kv_chunk_size_min,
-        use_checkpoint=use_checkpoint,
-    )
+    with devices.without_autocast(disable=q.dtype == v.dtype):
+        return efficient_dot_product_attention(
+            q,
+            k,
+            v,
+            query_chunk_size=q_chunk_size,
+            kv_chunk_size=kv_chunk_size,
+            kv_chunk_size_min = kv_chunk_size_min,
+            use_checkpoint=use_checkpoint,
+        )
 
 
 def get_xformers_flash_attention_op(q, k, v):
@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
 
+    out = out.to(dtype)
+
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
 
@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
         v = self.v(h_)
         b, c, h, w = q.shape
         q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        dtype = q.dtype
+        if shared.opts.upcast_attn:
+            q, k = q.float(), k.float()
         q = q.contiguous()
         k = k.contiguous()
         v = v.contiguous()
         out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+        out = out.to(dtype)
         out = rearrange(out, 'b (h w) c -> b c h w', h=h)
         out = self.proj_out(out)
         return x + out
@@ -1,4 +1,8 @@
 import torch
+from packaging import version
+
+from modules import devices
+from modules.sd_hijack_utils import CondFunc
 
 
 class TorchHijackForUnet:
@@ -28,3 +32,28 @@ class TorchHijackForUnet:
 
 
 th = TorchHijackForUnet()
+
+
+# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
+def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
+    for y in cond.keys():
+        cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+    with devices.autocast():
+        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
+
+class GELUHijack(torch.nn.GELU, torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        torch.nn.GELU.__init__(self, *args, **kwargs)
+    def forward(self, x):
+        if devices.unet_needs_upcast:
+            return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet)
+        else:
+            return torch.nn.GELU.forward(self, x)
+
+unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast
+CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast)
+CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast)
+if version.parse(torch.__version__) <= version.parse("1.13.1"):
+    CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast)
+    CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast)
+    CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU)
modules/sd_hijack_utils.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+import importlib
+
+class CondFunc:
+    def __new__(cls, orig_func, sub_func, cond_func):
+        self = super(CondFunc, cls).__new__(cls)
+        if isinstance(orig_func, str):
+            func_path = orig_func.split('.')
+            for i in range(len(func_path)-2, -1, -1):
+                try:
+                    resolved_obj = importlib.import_module('.'.join(func_path[:i]))
+                    break
+                except ImportError:
+                    pass
+            for attr_name in func_path[i:-1]:
+                resolved_obj = getattr(resolved_obj, attr_name)
+            orig_func = getattr(resolved_obj, func_path[-1])
+            setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs))
+        self.__init__(orig_func, sub_func, cond_func)
+        return lambda *args, **kwargs: self(*args, **kwargs)
+    def __init__(self, orig_func, sub_func, cond_func):
+        self.__orig_func = orig_func
+        self.__sub_func = sub_func
+        self.__cond_func = cond_func
+    def __call__(self, *args, **kwargs):
+        if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs):
+            return self.__sub_func(self.__orig_func, *args, **kwargs)
+        else:
+            return self.__orig_func(*args, **kwargs)
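CondFunc is a small conditional monkey-patching helper: given a target (a callable, or a dotted import path that gets resolved and replaced in place), a substitute, and a condition, calls are routed to the substitute, with the original function passed as the first argument, whenever the condition returns true, and fall through to the original otherwise. That is how the UNet hijacks in the hunk above are registered. A hedged usage sketch with made-up functions (orig_add and logged_add are not from the repo; it assumes the webui's modules package is importable):

from modules.sd_hijack_utils import CondFunc

def orig_add(a, b):
    return a + b

def logged_add(orig_func, a, b):
    # the substitute receives the original callable first, so it can wrap or fall through to it
    print("adding", a, b)
    return orig_func(a, b)

# only substitute when the condition holds; otherwise the original runs untouched
patched = CondFunc(orig_add, logged_add, lambda orig, a, b: a > 0)
print(patched(1, 2))    # condition true: logged_add prints and returns 3
print(patched(-1, 2))   # condition false: orig_add runs silently and returns 1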
@@ -258,16 +258,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo):
 
         if not shared.cmd_opts.no_half:
             vae = model.first_stage_model
+            depth_model = getattr(model, 'depth_model', None)
 
             # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
             if shared.cmd_opts.no_half_vae:
                 model.first_stage_model = None
+            # with --upcast-sampling, don't convert the depth model weights to float16
+            if shared.cmd_opts.upcast_sampling and depth_model:
+                model.depth_model = None
 
             model.half()
             model.first_stage_model = vae
+            if depth_model:
+                model.depth_model = depth_model
 
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+        devices.dtype_unet = model.model.diffusion_model.dtype
+        devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
 
         model.first_stage_model.to(devices.dtype_vae)
 
@@ -382,6 +390,8 @@ def load_model(checkpoint_info=None):
 
     if shared.cmd_opts.no_half:
         sd_config.model.params.unet_config.params.use_fp16 = False
+    elif shared.cmd_opts.upcast_sampling:
+        sd_config.model.params.unet_config.params.use_fp16 = True
 
     timer = Timer()
 
@@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
 parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
 parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
 parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
+parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.")
 parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
 parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
 parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
@@ -409,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
     "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
|
@ -67,7 +67,7 @@ def _summarize_chunk(
|
|||||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||||
max_score = max_score.detach()
|
max_score = max_score.detach()
|
||||||
exp_weights = torch.exp(attn_weights - max_score)
|
exp_weights = torch.exp(attn_weights - max_score)
|
||||||
exp_values = torch.bmm(exp_weights, value)
|
exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
|
||||||
max_score = max_score.squeeze(-1)
|
max_score = max_score.squeeze(-1)
|
||||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||||
|
|
||||||
@@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
     )
     attn_probs = attn_scores.softmax(dim=-1)
     del attn_scores
-    hidden_states_slice = torch.bmm(attn_probs, value)
+    hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
     return hidden_states_slice
 
 