diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 498bc9d8..7bbc443e 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -64,21 +64,26 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def apply_hypernetwork(hypernetwork, context): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + def attention_CrossAttention_forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) - v = self.to_v(context) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 18408e62..25cb67a4 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared +from modules import shared, hypernetwork + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) - del context, x + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + del context, context_k, context_v, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in @@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): return self.to_out(r2) -# taken from https://github.com/Doggettx/stable-diffusion +# taken from https://github.com/Doggettx/stable-diffusion and modified def split_cross_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) k_in *= self.scale @@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)