"""Mask geometry and fill helpers used for inpainting.

These are the contents of modules/masking.py. modules/processing.py imports
this module and calls get_crop_region(), expand_crop_region() and fill() from
StableDiffusionProcessingImg2Img; its own copies of get_crop_region() and
fill() were removed in favour of the ones defined here.
"""


def _blank_margins(occupied):
    """Return (leading, trailing) counts of False entries in a 1-D boolean array."""
    if not occupied.any():
        # Nothing is set: the whole axis counts as margin from both ends.
        return len(occupied), len(occupied)
    return int(occupied.argmax()), int(occupied[::-1].argmax())


def _fit_span(start, end, limit):
    """Slide the span (start, end) so it lies inside (0, limit).

    The span is first shifted back if it runs past ``limit``, then shifted
    forward if that pushed ``start`` below zero, and finally trimmed to
    ``limit`` when it is longer than the available extent.
    """
    if end >= limit:
        overshoot = end - limit
        start -= overshoot
        end -= overshoot
    if start < 0:
        end -= start
        start = 0
    if end >= limit:
        end = limit
    return start, end


def get_crop_region(mask, pad=0):
    """Find the rectangle that contains all masked (non-zero) areas of a mask.

    Args:
        mask: 2-D array of shape (height, width); a zero entry means
            "not masked".
        pad: number of pixels to grow the rectangle by on every side. The
            result is clamped to the mask bounds.

    Returns:
        (x1, y1, x2, y2) as ints. For example, if a user has painted the
        top-right part of a 512x512 image, the result may be
        (256, 0, 512, 256). For an all-zero mask with pad=0 the result is the
        inverted rectangle (width, height, 0, 0).
    """
    h, w = mask.shape

    painted = mask != 0
    crop_left, crop_right = _blank_margins(painted.any(axis=0))
    crop_top, crop_bottom = _blank_margins(painted.any(axis=1))

    return (
        int(max(crop_left - pad, 0)),
        int(max(crop_top - pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h)),
    )


def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """Expand a crop region to the aspect ratio of the processing resolution.

    For example, if the user drew a mask in a 128x32 region and the dimensions
    for processing are 512x512, the region is expanded to 128x128.

    Args:
        crop_region: (x1, y1, x2, y2), as returned by get_crop_region().
        processing_width: width of the image the region will be processed in.
        processing_height: height of the image the region will be processed in.
        image_width: width of the image the region was taken from.
        image_height: height of the image the region was taken from.

    Returns:
        The expanded (x1, y1, x2, y2). Only one axis is grown; the region is
        kept inside the image, so if the required extent is larger than the
        image the result is trimmed and the ratio is not reached. A region
        with non-positive width or height is returned unchanged.
    """
    x1, y1, x2, y2 = crop_region
    region_width = x2 - x1
    region_height = y2 - y1

    # An empty or inverted region (e.g. from an all-zero mask) has no
    # meaningful aspect ratio; bail out instead of dividing by zero.
    if region_width <= 0 or region_height <= 0:
        return x1, y1, x2, y2

    ratio_crop_region = region_width / region_height
    ratio_processing = processing_width / processing_height

    if ratio_crop_region > ratio_processing:
        # Region is too wide for the target ratio: grow it vertically.
        # ratio = width / height, hence height = width / ratio.
        desired_height = region_width / ratio_processing
        grow = int(desired_height - region_height)
        y1 -= grow // 2
        y2 += grow - grow // 2
        y1, y2 = _fit_span(y1, y2, image_height)
    else:
        # Region is too tall (or already matches): grow it horizontally.
        desired_width = region_height * ratio_processing
        grow = int(desired_width - region_width)
        x1 -= grow // 2
        x2 += grow - grow // 2
        x1, x2 = _fit_span(x1, x2, image_width)

    return x1, y1, x2, y2


def fill(image, mask):
    """Fill masked regions with colors from the image using blur.

    Not extremely effective.

    Args:
        image: PIL image to take colors from.
        mask: PIL image; it is converted to 'L' and inverted, and the result
            is used as the paste mask for ``image``.

    Returns:
        A new 'RGB' image of the same size as ``image``.
    """
    # Pillow is only needed by this function, so it is imported here; the
    # geometry helpers above operate on plain arrays and integers.
    from PIL import Image, ImageFilter, ImageOps

    size = (image.width, image.height)
    image_mod = Image.new('RGBA', size)

    # 'RGBa' is Pillow's premultiplied-alpha mode.
    keep_mask = ImageOps.invert(mask.convert('L'))
    image_masked = Image.new('RGBa', size)
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=keep_mask)

    # NOTE(review): image_masked is already 'RGBa', so this convert looks
    # redundant; kept so the processing sequence is unchanged.
    image_masked = image_masked.convert('RGBa')

    # Composite progressively sharper blurs of the pasted pixels, from the
    # widest radius down to radius 0.
    blur_passes = [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]
    for radius, repeats in blur_passes:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")