add fallback for xformers_attnblock_forward

This commit is contained in:
AUTOMATIC 2022-10-08 19:05:19 +03:00
parent a5550f0213
commit f9c5da1592

View File

@ -211,6 +211,7 @@ def cross_attention_attnblock_forward(self, x):
return h3 return h3
def xformers_attnblock_forward(self, x): def xformers_attnblock_forward(self, x):
try:
h_ = x h_ = x
h_ = self.norm(h_) h_ = self.norm(h_)
q1 = self.q(h_).contiguous() q1 = self.q(h_).contiguous()
@ -218,4 +219,6 @@ def xformers_attnblock_forward(self, x):
v = self.v(h_).contiguous() v = self.v(h_).contiguous()
out = xformers.ops.memory_efficient_attention(q1, k1, v) out = xformers.ops.memory_efficient_attention(q1, k1, v)
out = self.proj_out(out) out = self.proj_out(out)
return x+out return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)