From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 24 Jan 2023 23:51:45 -0500 Subject: [PATCH 1/2] Add option for float32 sampling with float16 UNet This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half(). --- README.md | 1 + modules/deepbooru_model.py | 4 +++- modules/devices.py | 2 ++ modules/processing.py | 15 ++++++++------- modules/sd_hijack_unet.py | 29 +++++++++++++++++++++++++++++ modules/sd_hijack_utils.py | 28 ++++++++++++++++++++++++++++ modules/sd_models.py | 10 ++++++++++ modules/shared.py | 1 + 8 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 modules/sd_hijack_utils.py diff --git a/README.md b/README.md index 9c0cd1ef..a5611671 100644 --- a/README.md +++ b/README.md @@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru - Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. +- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) - (You) diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py index edd40c81..83d2ff09 100644 --- a/modules/deepbooru_model.py +++ b/modules/deepbooru_model.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import devices + # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more @@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module): t_358, = inputs t_359 = t_358.permute(*[0, 3, 1, 2]) t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) - t_360 = self.n_Conv_0(t_359_padded) + t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded) t_361 = F.relu(t_360) t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) t_362 = self.n_MaxPool_0(t_361) diff --git a/modules/devices.py b/modules/devices.py index 524ec7af..0981ef80 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -79,6 +79,8 @@ cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 +dtype_unet = torch.float16 +unet_needs_upcast = False def randn(seed, shape): diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..2d186ba0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,8 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -203,7 +204,7 @@ class StableDiffusionProcessing: # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. - conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype) + conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), @@ -211,7 +212,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -225,10 +226,10 @@ class StableDiffusionProcessing: # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image) + return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(): + with devices.autocast(disable=devices.unet_needs_upcast): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] @@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(shared.device) + image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 18daf8c1..88c94e54 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,4 +1,8 @@ import torch +from packaging import version + +from modules import devices +from modules.sd_hijack_utils import CondFunc class TorchHijackForUnet: @@ -28,3 +32,28 @@ class TorchHijackForUnet: th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): + return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +if version.parse(torch.__version__) <= version.parse("1.13.1"): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py new file mode 100644 index 00000000..f81b169a --- /dev/null +++ b/modules/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-2, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half: vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 if shared.cmd_opts.no_half_vae: model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None model.half() model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True timer = Timer() diff --git a/modules/shared.py b/modules/shared.py index 5f713bee..4ce1209b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 00:23:10 -0500 Subject: [PATCH 2/2] Add UI setting for upcasting attention to float32 Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also. --- modules/devices.py | 6 +- modules/processing.py | 2 +- modules/sd_hijack_optimizations.py | 159 ++++++++++++++++++----------- modules/shared.py | 1 + modules/sub_quadratic_attention.py | 4 +- 5 files changed, 108 insertions(+), 64 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0981ef80..6b36622c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -108,6 +108,10 @@ def autocast(disable=False): return torch.autocast("cuda") +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + class NansException(Exception): pass @@ -125,7 +129,7 @@ def test_for_nans(x, where): message = "A tensor with all NaNs was produced in Unet." if not shared.cmd_opts.no_half: - message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." diff --git a/modules/processing.py b/modules/processing.py index 2d186ba0..a850082d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(disable=devices.unet_needs_upcast): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 74452709..c02d954c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors +from modules import shared, errors, devices from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() - s2 = s1.softmax(dim=-1) - del s1 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - k_in *= self.scale + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - del context, x + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): context = default(context, x) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) * self.scale + k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = einsum_op(q, k, v) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) # -- End of code from https://github.com/invoke-ai/InvokeAI -- @@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = x.to(dtype) + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = self.to_out @@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ query_chunk_size = q_tokens kv_chunk_size = k_tokens - return efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) def get_xformers_flash_attention_op(q, k, v): @@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 4ce1209b..6a0b96cb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 55052815..05595323 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -67,7 +67,7 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) @@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ) attn_probs = attn_scores.softmax(dim=-1) del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype) return hidden_states_slice