From 269833067de1e7d0b6a6bd65724743d6b88a133f Mon Sep 17 00:00:00 2001 From: Kyle Date: Thu, 2 Feb 2023 09:37:01 -0500 Subject: [PATCH 1/6] instruct-pix2pix support --- modules/processing.py | 2 +- modules/sd_samplers_kdiffusion.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index e544c2e1..f299e04d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -186,7 +186,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.encode_first_stage(source_image).mode() return conditioning_image diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index aa7f106b..31ee22d3 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -77,9 +77,9 @@ class CFGDenoiser(torch.nn.Module): batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -88,7 +88,7 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) + cond_in = torch.cat([tensor, uncond, uncond]) if shared.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) From cf0cfefe910b0de18c4751ce8d8cf7a6053a39b0 Mon Sep 17 00:00:00 2001 From: Kyle Date: Thu, 2 Feb 2023 19:15:38 -0500 Subject: [PATCH 2/6] Revert "instruct-pix2pix support" This reverts commit 269833067de1e7d0b6a6bd65724743d6b88a133f. --- modules/processing.py | 2 +- modules/sd_samplers_kdiffusion.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f299e04d..e544c2e1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -186,7 +186,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.encode_first_stage(source_image).mode() + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) return conditioning_image diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 31ee22d3..aa7f106b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -77,9 +77,9 @@ class CFGDenoiser(torch.nn.Module): batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond]) + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -88,7 +88,7 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = torch.cat([tensor, uncond]) if shared.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) From 3b2ad20ac1753cb664bd8954dd34f0c04d3678c2 Mon Sep 17 00:00:00 2001 From: Kyle Date: Thu, 2 Feb 2023 19:19:45 -0500 Subject: [PATCH 3/6] Processing only, no CFGDenoiser change Allows instruct-pix2pix --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index e544c2e1..f299e04d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -186,7 +186,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.encode_first_stage(source_image).mode() return conditioning_image From 6c6c6636bb123d664999c888cda47a1f8bad635b Mon Sep 17 00:00:00 2001 From: Kyle Date: Fri, 3 Feb 2023 18:19:56 -0500 Subject: [PATCH 4/6] Image CFG Added (Full Implementation) Uses separate denoiser for edit (instruct-pix2pix) models No impact to txt2img or regular img2img "Image CFG Scale" will only apply to instruct-pix2pix models and metadata will only be added if using such model --- modules/img2img.py | 3 +- modules/processing.py | 4 +- modules/sd_samplers_kdiffusion.py | 101 ++++++++++++++++++++++++++++-- modules/ui.py | 3 + 4 files changed, 103 insertions(+), 8 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index f813299c..bcc158dc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -76,7 +76,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -142,6 +142,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s inpainting_fill=inpainting_fill, resize_mode=resize_mode, denoising_strength=denoising_strength, + image_cfg_scale=image_cfg_scale, inpaint_full_res=inpaint_full_res, inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, diff --git a/modules/processing.py b/modules/processing.py index f299e04d..c33694cc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -445,6 +445,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Steps": p.steps, "Sampler": p.sampler_name, "CFG scale": p.cfg_scale, + "Image CFG scale": getattr(p, 'image_cfg_scale', None), "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", @@ -901,12 +902,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): + def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): super().__init__(**kwargs) self.init_images = init_images self.resize_mode: int = resize_mode self.denoising_strength: float = denoising_strength + self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None self.init_latent = None self.image_mask = mask self.latent_mask = None diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index aa7f106b..a16ba69b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,6 +1,7 @@ from collections import deque import torch import inspect +import einops import k_diffusion.sampling from modules import prompt_parser, devices, sd_samplers_common @@ -40,6 +41,90 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } +class CFGDenoiserEdit(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, model): + super().__init__() + self.inner_model = model + self.mask = None + self.nmask = None + self.init_latent = None + self.step = 0 + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index]) + + return denoised + + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + + if tensor.shape[1] == uncond.shape[1]: + cond_in = torch.cat([tensor, uncond, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) + else: + x_out = torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]}) + + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + + devices.test_for_nans(x_out, "unet") + + if opts.live_preview_content == "Prompt": + sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) + elif opts.live_preview_content == "Negative prompt": + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) + + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale) + + if self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + self.step += 1 + + return denoised + class CFGDenoiser(torch.nn.Module): """ @@ -78,8 +163,8 @@ class CFGDenoiser(torch.nn.Module): repeats = [len(conds_list[i]) for i in range(batch_size)] x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -160,7 +245,7 @@ class KDiffusionSampler: self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) if not shared.sd_model.cond_stage_key == "edit" else CFGDenoiserEdit(self.model_wrap) self.sampler_noises = None self.stop_at = None self.eta = None @@ -260,13 +345,17 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ + extra_args={ 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + 'cond_scale': p.cfg_scale, + } + + if p.image_cfg_scale: + extra_args['image_cfg_scale'] = p.image_cfg_scale + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples diff --git a/modules/ui.py b/modules/ui.py index 5e34fb07..f2f7de8b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -766,6 +766,7 @@ def create_ui(): elif category == "cfg": with FormGroup(): cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale (for instruct-pix2pix models only)', value=1.5, elem_id="img2img_image_cfg_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") elif category == "seed": @@ -861,6 +862,7 @@ def create_ui(): batch_count, batch_size, cfg_scale, + image_cfg_scale, denoising_strength, seed, subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, @@ -947,6 +949,7 @@ def create_ui(): (sampler_index, "Sampler"), (restore_faces, "Face restoration"), (cfg_scale, "CFG scale"), + (image_cfg_scale, "Image CFG scale"), (seed, "Seed"), (width, "Size-1"), (height, "Size-2"), From c27c0de0f73c5f533acfa10426dbac7ac988bc85 Mon Sep 17 00:00:00 2001 From: Kyle Date: Fri, 3 Feb 2023 19:15:32 -0500 Subject: [PATCH 5/6] txt2img Hires Fix --- modules/processing.py | 1 + modules/sd_samplers_kdiffusion.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index c33694cc..e1b53ac0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -268,6 +268,7 @@ class Processed: self.height = p.height self.sampler_name = p.sampler_name self.cfg_scale = p.cfg_scale + self.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.steps = p.steps self.batch_size = p.batch_size self.restore_faces = p.restore_faces diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a16ba69b..6107e99e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -352,7 +352,7 @@ class KDiffusionSampler: 'cond_scale': p.cfg_scale, } - if p.image_cfg_scale: + if hasattr(p, 'image_cfg_scale'): extra_args['image_cfg_scale'] = p.image_cfg_scale samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) From ba6a4e7e9431d02ba3656c6ae44d5dfe29908d68 Mon Sep 17 00:00:00 2001 From: Kyle Date: Fri, 3 Feb 2023 19:46:13 -0500 Subject: [PATCH 6/6] Use original CFGDenoiser if image_cfg_scale = 1 If image_cfg_scale is =1 then the original image is not used for the output. We can then use the original CFGDenoiser to get the same result to support AND functionality. Maybe in the future AND can be supported with "Image CFG Scale" --- modules/sd_samplers_kdiffusion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 6107e99e..6c57fdec 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -245,7 +245,7 @@ class KDiffusionSampler: self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) if not shared.sd_model.cond_stage_key == "edit" else CFGDenoiserEdit(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None self.stop_at = None self.eta = None @@ -280,6 +280,9 @@ class KDiffusionSampler: return p.steps def initialize(self, p): + if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1: + self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap) + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 @@ -352,7 +355,7 @@ class KDiffusionSampler: 'cond_scale': p.cfg_scale, } - if hasattr(p, 'image_cfg_scale'): + if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None: extra_args['image_cfg_scale'] = p.image_cfg_scale samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))