diff --git a/launch.py b/launch.py
index 75edb66a..a592e1ba 100644
--- a/launch.py
+++ b/launch.py
@@ -4,6 +4,7 @@ import os
 import sys
 import importlib.util
 import shlex
+import platform
 
 dir_repos = "repositories"
 dir_tmp = "tmp"
@@ -31,6 +32,7 @@ def extract_arg(args, name):
 
 
 args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
+args, xformers = extract_arg(args, '--xformers')
 
 
 def repo_dir(name):
@@ -124,6 +126,12 @@ if not is_installed("gfpgan"):
 if not is_installed("clip"):
     run_pip(f"install {clip_package}", "clip")
 
+if not is_installed("xformers") and xformers:
+    if platform.system() == "Windows":
+        run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
+    elif platform.system() == "Linux":
+        run_pip("install xformers", "xformers")
+
 os.makedirs(dir_repos, exist_ok=True)
 
 git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index ba808a39..5d93f7f6 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -22,11 +22,13 @@ def apply_optimizations():
     undo_optimizations()
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
-
-    if cmd_opts.opt_split_attention_v1:
+    if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available:
+        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
+    elif cmd_opts.opt_split_attention_v1:
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
+    elif cmd_opts.opt_split_attention or torch.cuda.is_available():
         ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
         ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
 
 
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 3351c740..e43e2c7a 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,14 @@
 import math
 import torch
 from torch import einsum
-
+from modules import shared
+try:
+    import xformers.ops
+    import functorch
+    xformers._is_functorch_available = True
+    shared.xformers_available = True
+except Exception:
+    print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')
 from ldm.util import default
 from einops import rearrange
 
@@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     return self.to_out(r2)
 
+def xformers_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+    if hypernetwork_layers is not None:
+        k_in = self.to_k(hypernetwork_layers[0](context))
+        v_in = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+    del q_in, k_in, v_in
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+    return self.to_out(out)
+
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
@@ -177,3 +203,16 @@ def cross_attention_attnblock_forward(self, x):
     h3 += x
 
     return h3
+
+def xformers_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c').contiguous(), (q, k, v))
+    out = xformers.ops.memory_efficient_attention(q, k, v)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
diff --git a/modules/shared.py b/modules/shared.py
index 475d7e52..d68df751 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
 parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
 parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
 parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
+parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization")
 parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
 parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
@@ -73,7 +74,7 @@ device = devices.device
 batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
 
 parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
-
+xformers_available = False
 config_filename = cmd_opts.ui_settings_file
 
 hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
diff --git a/requirements.txt b/requirements.txt
index 631fe616..81641d68 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,3 +23,4 @@ resize-right
 torchdiffeq
 kornia
 lark
+functorch
diff --git a/requirements_versions.txt b/requirements_versions.txt
index fdff2687..fec3e9d5 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -22,3 +22,4 @@ resize-right==0.0.2
 torchdiffeq==0.2.3
 kornia==0.6.7
 lark==1.1.2
+functorch==0.2.1