fix #4459 breaking inpainting when the option is not specified.

This commit is contained in:
AUTOMATIC 2022-12-04 01:04:24 +03:00
parent cefb5d6d7d
commit 8504db5170
2 changed files with 23 additions and 19 deletions

View File

@ -4,7 +4,7 @@ import sys
import traceback import traceback
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageFilter, ImageEnhance from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
from modules import devices, sd_samplers from modules import devices, sd_samplers
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
@ -66,22 +66,23 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if is_inpaint: if is_inpaint:
# Drawn mask # Drawn mask
if mask_mode == 0: if mask_mode == 0:
image = init_img_with_mask is_mask_sketch = isinstance(init_img_with_mask, dict)
is_mask_sketch = isinstance(image, dict)
is_mask_paint = not is_mask_sketch is_mask_paint = not is_mask_sketch
if is_mask_sketch: if is_mask_sketch:
# Sketch: mask iff. not transparent # Sketch: mask iff. not transparent
image, mask = image["image"], image["mask"] image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
pred = np.array(mask)[..., -1] > 0 alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
else: else:
# Color-sketch: mask iff. painted over # Color-sketch: mask iff. painted over
orig = init_img_with_mask_orig or image image = init_img_with_mask
orig = init_img_with_mask_orig or init_img_with_mask
pred = np.any(np.array(image) != np.array(orig), axis=-1) pred = np.any(np.array(image) != np.array(orig), axis=-1)
mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") mask = Image.fromarray(pred.astype(np.uint8) * 255, "L")
if is_mask_paint:
mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100)
blur = ImageFilter.GaussianBlur(mask_blur) blur = ImageFilter.GaussianBlur(mask_blur)
image = Image.composite(image.filter(blur), orig, mask.filter(blur)) image = Image.composite(image.filter(blur), orig, mask.filter(blur))
image = image.convert("RGB") image = image.convert("RGB")
# Uploaded mask # Uploaded mask
else: else:

View File

@ -791,23 +791,26 @@ def create_ui():
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480) init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool).style(height=480)
with gr.TabItem('Inpaint', id='inpaint'): with gr.TabItem('Inpaint', id='inpaint'):
init_img_with_mask_orig = gr.State(None)
init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480)
init_img_with_mask_orig = gr.State(None)
def update_orig(image, state): use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch"
if image is not None: if use_color_sketch:
same_size = state is not None and state.size == image.size def update_orig(image, state):
has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) if image is not None:
edited = same_size and has_exact_match same_size = state is not None and state.size == image.size
return image if not edited or state is None else state has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1))
edited = same_size and has_exact_match
return image if not edited or state is None else state
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig)
init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base")
init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask")
show_mask_alpha = cmd_opts.gradio_inpaint_tool == "color-sketch" with gr.Row():
mask_alpha = gr.Slider(label="Mask transparency", interactive=show_mask_alpha, visible=show_mask_alpha) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch)
with gr.Row(): with gr.Row():
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")