Patch UNet Forward to support resolutions that are not multiples of 64

Also modifed the UI to no longer step in 64
This commit is contained in:
Billy Cao 2022-11-23 18:11:24 +08:00
parent 828438b4a1
commit adb6cb7619
3 changed files with 45 additions and 12 deletions

View File

@ -16,6 +16,7 @@ import ldm.modules.attention
import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms import ldm.models.diffusion.plms
import ldm.modules.diffusionmodules.openaimodel
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@ -26,6 +27,7 @@ def apply_optimizations():
undo_optimizations() undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")

View File

@ -5,6 +5,7 @@ import importlib
import torch import torch
from torch import einsum from torch import einsum
import torch.nn.functional as F
from ldm.util import default from ldm.util import default
from einops import rearrange from einops import rearrange
@ -12,6 +13,8 @@ from einops import rearrange
from modules import shared from modules import shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from ldm.modules.diffusionmodules.util import timestep_embedding
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try: try:
@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
return x + out return x + out
except NotImplementedError: except NotImplementedError:
return cross_attention_attnblock_forward(self, x) return cross_attention_attnblock_forward(self, x)
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
if h.shape[-2:] != hs[-1].shape[-2:]:
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)

View File

@ -380,8 +380,8 @@ def create_seed_inputs():
with gr.Row(visible=False) as seed_extra_row_2: with gr.Row(visible=False) as seed_extra_row_2:
seed_extras.append(seed_extra_row_2) seed_extras.append(seed_extra_row_2)
seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from width", value=0) seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from width", value=0)
seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=64, label="Resize seed from height", value=0) seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=1, label="Resize seed from height", value=0)
random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
@ -715,8 +715,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Group(): with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
with gr.Row(): with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@ -724,8 +724,8 @@ def create_ui(wrap_gradio_gpu_call):
enable_hr = gr.Checkbox(label='Highres. fix', value=False) enable_hr = gr.Checkbox(label='Highres. fix', value=False)
with gr.Row(visible=False) as hr_options: with gr.Row(visible=False) as hr_options:
firstphase_width = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass width", value=0) firstphase_width = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass width", value=0)
firstphase_height = gr.Slider(minimum=0, maximum=1024, step=64, label="Firstpass height", value=0) firstphase_height = gr.Slider(minimum=0, maximum=1024, step=1, label="Firstpass height", value=0)
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7)
with gr.Row(equal_height=True): with gr.Row(equal_height=True):
@ -901,8 +901,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group(): with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width") width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height") height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512, elem_id="img2img_height")
with gr.Row(): with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
@ -1231,8 +1231,8 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Preprocess images"): with gr.Tab(label="Preprocess images"):
process_src = gr.Textbox(label='Source directory') process_src = gr.Textbox(label='Source directory')
process_dst = gr.Textbox(label='Destination directory') process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) process_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) process_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row(): with gr.Row():
@ -1289,8 +1289,8 @@ def create_ui(wrap_gradio_gpu_call):
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_width = gr.Slider(minimum=64, maximum=2048, step=1, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=1, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0) steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)