make it possible to use hypernetworks without opt split attention

AUTOMATIC 2022-10-07 16:39:51 +03:00
parent 97bc0b9504
commit f7c787eb7c
2 changed files with 38 additions and 10 deletions

modules/hypernetwork.py

@@ -4,7 +4,12 @@ import sys
import traceback
import torch
from modules import devices
from ldm.util import default
from modules import devices, shared
import torch
from torch import einsum
from einops import rearrange, repeat
class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
    return res


def apply(self, x, context=None, mask=None, original=None):
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
        if context.shape[1] == 77 and CrossAttention.noise_cond:
            context = context + (torch.randn_like(context) * 0.1)
        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
        k = self.to_k(h_k(context))
        v = self.to_v(h_v(context))

    q = self.to_q(x)
    context = default(context, x)

    hypernetwork = shared.selected_hypernetwork()
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
        k = self.to_k(hypernetwork_layers[0](context))
        v = self.to_v(hypernetwork_layers[1](context))
    else:
        k = self.to_k(context)
        v = self.to_v(context)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)
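
The new attention_CrossAttention_forward keys its hypernetwork lookup on context.shape[2] (the width of the conditioning embedding) and, when a matching pair of modules exists, runs the context through them before the k and v projections; otherwise it falls back to the plain projections. The sketch below is a minimal, self-contained illustration of that lookup-and-transform pattern, not the webui code itself: the TinyHypernetwork classes, the 768-wide conditioning and the layer sizes are assumptions made for the example.

# Minimal sketch of the hypernetwork k/v-context transform (module shapes,
# 768-dim conditioning and class names here are illustrative assumptions).
import torch


class TinyHypernetworkModule(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

    def forward(self, x):
        # residual transform of the conditioning, leaving its shape unchanged
        return x + self.linear2(self.linear1(x))


class TinyHypernetwork:
    def __init__(self, dims=(768,)):
        # one (k, v) module pair per context embedding width
        self.layers = {dim: (TinyHypernetworkModule(dim), TinyHypernetworkModule(dim)) for dim in dims}


def transform_context_for_kv(context, hypernetwork=None):
    # mirrors the lookup in attention_CrossAttention_forward: key on context.shape[2]
    layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
    if layers is None:
        return context, context
    return layers[0](context), layers[1](context)


if __name__ == "__main__":
    context = torch.randn(2, 77, 768)          # (batch, tokens, embedding width)
    context_k, context_v = transform_context_for_kv(context, TinyHypernetwork())
    print(context_k.shape, context_v.shape)    # both stay (2, 77, 768)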

modules/sd_hijack.py

@@ -8,7 +8,7 @@ from torch import einsum
from torch.nn.functional import silu
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
from modules.shared import opts, device, cmd_opts
import ldm.modules.attention
@@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations():
    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu

    if cmd_opts.opt_split_attention_v1:
@@ -30,7 +32,7 @@ def apply_optimizations():
def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
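
With this change, undo_optimizations() no longer restores the stock CrossAttention.forward but installs the hypernetwork-aware hypernetwork.attention_CrossAttention_forward, so the hypernetwork is applied even when none of the opt-split-attention code paths are active. The snippet below is a minimal illustration of that class-level monkey-patch pattern under assumed names (CrossAttentionLike and replacement_forward are not part of the webui): assigning a module-level function to the class makes it the bound forward for every existing and future instance.

# Minimal sketch of the monkey-patch pattern used above (class and function
# names here are illustrative, not the webui's).
import torch


class CrossAttentionLike(torch.nn.Module):
    def forward(self, x, context=None, mask=None):
        return x


def replacement_forward(self, x, context=None, mask=None):
    # a plain function assigned to the class becomes the new bound method,
    # receiving `self` exactly like the original forward did
    return x * 2


CrossAttentionLike.forward = replacement_forward

module = CrossAttentionLike()
print(module(torch.ones(1)))  # tensor([2.]) -- every instance now uses the patched forward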