From c84e333622883be7f1c2efa08259d36d27cadec7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 13 Sep 2022 12:51:57 +0300 Subject: [PATCH] color correction option for all img2img modes #363 --- modules/img2img.py | 22 ---------------------- modules/processing.py | 28 ++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index e3109121..70c99e33 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -1,5 +1,4 @@ import math -import cv2 import numpy as np from PIL import Image, ImageOps, ImageChops @@ -76,18 +75,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init state.job_count = n_iter - do_color_correction = False - try: - from skimage import exposure - do_color_correction = True - except: - print("Install scikit-image to perform color correction on loopback") - - for i in range(n_iter): - if do_color_correction and i == 0: - correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB) - p.n_iter = 1 p.batch_size = 1 p.do_not_save_grid = True @@ -101,16 +89,6 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init init_img = processed.images[0] - if do_color_correction and correction_target is not None: - init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms( - cv2.cvtColor( - np.asarray(init_img), - cv2.COLOR_RGB2LAB - ), - correction_target, - channel_axis=2 - ), cv2.COLOR_LAB2RGB).astype("uint8")) - p.init_images = [init_img] p.seed = processed.seed + 1 p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) diff --git a/modules/processing.py b/modules/processing.py index 65ae4846..c0b89244 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -8,6 +8,8 @@ import torch import numpy as np from PIL import Image, ImageFilter, ImageOps import random +import cv2 +from skimage import exposure import modules.sd_hijack from modules import devices @@ -19,11 +21,30 @@ import modules.face_restoration import modules.images as images import modules.styles + # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 opt_f = 8 +def setup_color_correction(image): + correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB) + return correction_target + + +def apply_color_correction(correction, image): + image = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + cv2.cvtColor( + np.asarray(image), + cv2.COLOR_RGB2LAB + ), + correction, + channel_axis=2 + ), cv2.COLOR_LAB2RGB).astype("uint8")) + + return image + + class StableDiffusionProcessing: def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): self.sd_model = sd_model @@ -52,6 +73,7 @@ class StableDiffusionProcessing: self.extra_generation_params: dict = extra_generation_params self.overlay_images = overlay_images self.paste_to = None + self.color_corrections = None def init(self, seed): pass @@ -265,6 +287,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample) + if p.color_corrections is not None and i < len(p.color_corrections): + image = apply_color_correction(p.color_corrections[i], image) if p.overlay_images is not None and i < len(p.overlay_images): overlay = p.overlay_images[i] @@ -420,6 +444,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask + self.color_corrections = [] imgs = [] for img in self.init_images: image = img.convert("RGB") @@ -441,6 +466,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_fill != 1: image = fill(image, latent_mask) + if opts.img2img_color_correction: + self.color_corrections.append(setup_color_correction(image)) + image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) diff --git a/modules/shared.py b/modules/shared.py index fd6b7fb4..594712d4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -122,6 +122,7 @@ class Options: "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "add_model_hash_to_info": OptionInfo(False, "Add model hash to generation information"), + "img2img_color_correction": OptionInfo(True, "Apply color correction to img2img results to match original colors."), "font": OptionInfo("", "Font for image grids that have text"), "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),