diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 8625690b..3eeb84d5 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") -parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") -parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) -parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") -parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") -parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*") -parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization") +parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization") +parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*") +parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 3c21a362..40f388a5 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -110,6 +110,7 @@ callback_map = dict( callbacks_script_unloaded=[], callbacks_before_ui=[], callbacks_on_reload=[], + callbacks_list_optimizers=[], ) @@ -258,6 +259,18 @@ def before_ui_callback(): report_exception(c, 'before_ui') +def list_optimizers_callback(): + res = [] + + for c in callback_map['callbacks_list_optimizers']: + try: + c.callback(res) + except Exception: + report_exception(c, 'list_optimizers') + + return res + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -409,3 +422,11 @@ def on_before_ui(callback): """register a function to be called before the UI is created.""" add_callback(callback_map['callbacks_before_ui'], callback) + + +def on_list_optimizers(callback): + """register a function to be called when UI is making a list of cross attention optimization options. + The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization + to it.""" + + add_callback(callback_map['callbacks_list_optimizers'], callback) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 14e7f799..08d31080 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,7 +3,7 @@ from torch.nn.functional import silu from types import MethodType import modules.textual_inversion.textual_inversion -from modules import devices, sd_hijack_optimizations, shared +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr @@ -28,57 +28,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] ldm.modules.attention.print = lambda *args: None ldm.modules.diffusionmodules.model.print = lambda *args: None +optimizers = [] +current_optimizer: sd_hijack_optimizations.SdOptimization = None + + +def list_optimizers(): + new_optimizers = script_callbacks.list_optimizers_callback() + + new_optimizers = [x for x in new_optimizers if x.is_available()] + + new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) + + optimizers.clear() + optimizers.extend(new_optimizers) + def apply_optimizations(): + global current_optimizer + undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th - optimization_method = None + if current_optimizer is not None: + current_optimizer.undo() + current_optimizer = None - can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp + selection = shared.opts.cross_attention_optimization + if selection == "Automatic" and len(optimizers) > 0: + matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) + else: + matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None) - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): - print("Applying xformers cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward - optimization_method = 'xformers' - elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: - print("Applying scaled dot product cross attention optimization (without memory efficient attention).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward - optimization_method = 'sdp-no-mem' - elif cmd_opts.opt_sdp_attention and can_use_sdp: - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward - optimization_method = 'sdp' - elif cmd_opts.opt_sub_quad_attention: - print("Applying sub-quadratic cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward - optimization_method = 'sub-quadratic' - elif cmd_opts.opt_split_attention_v1: - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - optimization_method = 'V1' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI - optimization_method = 'InvokeAI' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization (Doggettx).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - optimization_method = 'Doggettx' + if selection == "None": + matching_optimizer = None + elif matching_optimizer is None: + matching_optimizer = optimizers[0] - return optimization_method + if matching_optimizer is not None: + print(f"Applying optimization: {matching_optimizer.name}") + matching_optimizer.apply() + current_optimizer = matching_optimizer + return current_optimizer.name + else: + return '' def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward @@ -169,7 +168,11 @@ class StableDiffusionModelHijack: if m.cond_stage_key == "edit": sd_hijack_unet.hijack_ddpm_edit() - self.optimization_method = apply_optimizations() + try: + self.optimization_method = apply_optimizations() + except Exception as e: + errors.display(e, "applying cross attention optimization") + undo_optimizations() self.clip = m.cond_stage_model @@ -223,6 +226,10 @@ class StableDiffusionModelHijack: return token_count, self.clip.get_target_prompt_token_count(token_count) + def redo_hijack(self, m): + self.undo_hijack(m) + self.hijack(m) + class EmbeddingsWithFixes(torch.nn.Module): def __init__(self, wrapped, embeddings): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f00fe55c..0eb4c525 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,10 +9,129 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors, devices +from modules import shared, errors, devices, sub_quadratic_attention from modules.hypernetworks import hypernetwork -from .sub_quadratic_attention import efficient_dot_product_attention +import ldm.modules.attention +import ldm.modules.diffusionmodules.model + +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + + +class SdOptimization: + name: str = None + label: str | None = None + cmd_opt: str | None = None + priority: int = 0 + + def title(self): + if self.label is None: + return self.name + + return f"{self.name} - {self.label}" + + def is_available(self): + return True + + def apply(self): + pass + + def undo(self): + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + + +class SdOptimizationXformers(SdOptimization): + name = "xformers" + cmd_opt = "xformers" + priority = 100 + + def is_available(self): + return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) + + def apply(self): + ldm.modules.attention.CrossAttention.forward = xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + + +class SdOptimizationSdpNoMem(SdOptimization): + name = "sdp-no-mem" + label = "scaled dot product without memory efficient attention" + cmd_opt = "opt_sdp_no_mem_attention" + priority = 90 + + def is_available(self): + return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) + + def apply(self): + ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + + +class SdOptimizationSdp(SdOptimizationSdpNoMem): + name = "sdp" + label = "scaled dot product" + cmd_opt = "opt_sdp_attention" + priority = 80 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + + +class SdOptimizationSubQuad(SdOptimization): + name = "sub-quadratic" + cmd_opt = "opt_sub_quad_attention" + priority = 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + + +class SdOptimizationV1(SdOptimization): + name = "V1" + label = "original v1" + cmd_opt = "opt_split_attention_v1" + priority = 10 + + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + + +class SdOptimizationInvokeAI(SdOptimization): + name = "InvokeAI" + cmd_opt = "opt_split_attention_invokeai" + + @property + def priority(self): + return 1000 if not torch.cuda.is_available() else 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + + +class SdOptimizationDoggettx(SdOptimization): + name = "Doggettx" + cmd_opt = "opt_split_attention" + priority = 20 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + + +def list_optimizers(res): + res.extend([ + SdOptimizationXformers(), + SdOptimizationSdpNoMem(), + SdOptimizationSdp(), + SdOptimizationSubQuad(), + SdOptimizationV1(), + SdOptimizationInvokeAI(), + SdOptimizationDoggettx(), + ]) if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: @@ -299,7 +418,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ kv_chunk_size = k_tokens with devices.without_autocast(disable=q.dtype == v.dtype): - return efficient_dot_product_attention( + return sub_quadratic_attention.efficient_dot_product_attention( q, k, v, diff --git a/modules/shared.py b/modules/shared.py index e53b1e11..3099d1d2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -418,6 +418,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { })) options_templates.update(options_section(('optimizations', "Optimizations"), { + "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), diff --git a/modules/shared_items.py b/modules/shared_items.py index e792a134..2a8713c8 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -21,3 +21,11 @@ def refresh_vae_list(): import modules.sd_vae modules.sd_vae.refresh_vae_list() + + +def cross_attention_optimizations(): + import modules.sd_hijack + + return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] + + diff --git a/webui.py b/webui.py index 5c89a3b8..a76e377c 100644 --- a/webui.py +++ b/webui.py @@ -53,6 +53,7 @@ import modules.img2img import modules.lowvram import modules.scripts import modules.sd_hijack +import modules.sd_hijack_optimizations import modules.sd_models import modules.sd_vae import modules.txt2img @@ -224,6 +225,7 @@ def configure_opts_onchange(): shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) + shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) startup_timer.record("opts onchange") @@ -283,6 +285,10 @@ def initialize_rest(*, reload_script_modules=False): modules.textual_inversion.textual_inversion.list_textual_inversion_templates() startup_timer.record("refresh textual inversion templates") + modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) + modules.sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + # load model in parallel to other startup stuff # (when reloading, this does nothing) Thread(target=lambda: shared.sd_model).start() @@ -447,6 +453,10 @@ def webui(): startup_timer.record("scripts unloaded callback") initialize_rest(reload_script_modules=True) + modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) + modules.sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + if __name__ == "__main__": if cmd_opts.nowebui: