added preview option

This commit is contained in:
AUTOMATIC 2022-09-06 19:33:51 +03:00
parent db6db585eb
commit fd66199769
7 changed files with 102 additions and 12 deletions

View File

@ -176,6 +176,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n+1} out of {p.n_iter}"
samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc) samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
if state.interrupted:
# if we are interruped, sample returns just noise
# use the image collected previously in sampler loop
samples_ddim = shared.state.current_latent
x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

View File

@ -42,6 +42,8 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts) img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
state.current_latent = x_dec
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs) return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
@ -141,6 +143,9 @@ class KDiffusionSampler:
self.func = getattr(k_diffusion.sampling, self.funcname) self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
def callback_state(self, d):
state.current_latent = d["denoised"]
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps) t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
sigmas = self.model_wrap.get_sigmas(p.steps) sigmas = self.model_wrap.get_sigmas(p.steps)
@ -157,7 +162,7 @@ class KDiffusionSampler:
if hasattr(k_diffusion.sampling, 'trange'): if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs) k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False) return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
def sample(self, p, x, conditioning, unconditional_conditioning): def sample(self, p, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps) sigmas = self.model_wrap.get_sigmas(p.steps)
@ -166,6 +171,6 @@ class KDiffusionSampler:
if hasattr(k_diffusion.sampling, 'trange'): if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs) k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False) samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
return samples_ddim return samples_ddim

View File

@ -39,6 +39,7 @@ gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
class State: class State:
interrupted = False interrupted = False
job = "" job = ""
@ -46,6 +47,8 @@ class State:
job_count = 0 job_count = 0
sampling_step = 0 sampling_step = 0
sampling_steps = 0 sampling_steps = 0
current_latent = None
current_image = None
def interrupt(self): def interrupt(self):
self.interrupted = True self.interrupted = True
@ -99,6 +102,7 @@ class Options:
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
"upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}), "upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
} }
def __init__(self): def __init__(self):

View File

@ -9,6 +9,8 @@ import sys
import time import time
import traceback import traceback
import numpy as np
import torch
from PIL import Image from PIL import Image
import gradio as gr import gradio as gr
@ -119,6 +121,9 @@ def wrap_gradio_call(func):
print("Arguments:", args, kwargs, file=sys.stderr) print("Arguments:", args, kwargs, file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
shared.state.job = ""
shared.state.job_count = 0
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"] res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
@ -134,11 +139,9 @@ def wrap_gradio_call(func):
def check_progress_call(): def check_progress_call():
if not opts.show_progressbar:
return ""
if shared.state.job_count == 0: if shared.state.job_count == 0:
return "" return "", gr_show(False), gr_show(False)
progress = 0 progress = 0
@ -149,9 +152,29 @@ def check_progress_call():
progress = min(progress, 1) progress = min(progress, 1)
progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>""" progressbar = ""
if opts.show_progressbar:
progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>" image = gr_show(False)
preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0:
if (shared.state.sampling_step-1) % opts.show_progress_every_n_steps == 0 and shared.state.current_latent is not None:
x_sample = shared.sd_model.decode_first_stage(shared.state.current_latent[0:1].type(shared.sd_model.dtype))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
shared.state.current_image = Image.fromarray(x_sample)
image = shared.state.current_image
if image is None or progress >= 1:
image = gr.update(value=None)
else:
preview_visibility = gr_show(True)
return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
def roll_artist(prompt): def roll_artist(prompt):
@ -204,6 +227,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
with gr.Group(): with gr.Group():
txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery') txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery')
@ -251,8 +275,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
check_progress.click( check_progress.click(
fn=check_progress_call, fn=check_progress_call,
show_progress=False,
inputs=[], inputs=[],
outputs=[progressbar], outputs=[progressbar, txt2img_preview, txt2img_preview],
) )
@ -337,13 +362,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
with gr.Group(): with gr.Group():
img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery') img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery')
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
interrupt = gr.Button('Interrupt')
save = gr.Button('Save') save = gr.Button('Save')
img2img_send_to_img2img = gr.Button('Send to img2img')
img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras') img2img_send_to_extras = gr.Button('Send to extras')
interrupt = gr.Button('Interrupt')
progressbar = gr.HTML(elem_id="progressbar") progressbar = gr.HTML(elem_id="progressbar")
@ -426,8 +454,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
check_progress.click( check_progress.click(
fn=check_progress_call, fn=check_progress_call,
show_progress=False,
inputs=[], inputs=[],
outputs=[progressbar], outputs=[progressbar, img2img_preview, img2img_preview],
) )
interrupt.click( interrupt.click(
@ -463,6 +492,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[init_img_with_mask], outputs=[init_img_with_mask],
) )
img2img_send_to_img2img.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[img2img_gallery],
outputs=[init_img],
)
img2img_send_to_inpaint.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[img2img_gallery],
outputs=[init_img_with_mask],
)
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):

View File

@ -79,6 +79,23 @@ function addTitles(root){
global_progressbar = progressbar global_progressbar = progressbar
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
txt2img_preview = gradioApp().getElementById('txt2img_preview')
txt2img_gallery = gradioApp().getElementById('txt2img_gallery')
img2img_preview = gradioApp().getElementById('img2img_preview')
img2img_gallery = gradioApp().getElementById('img2img_gallery')
if(txt2img_preview != null && txt2img_gallery != null){
txt2img_preview.style.width = txt2img_gallery.clientWidth + "px"
txt2img_preview.style.height = txt2img_gallery.clientHeight + "px"
}
if(img2img_preview != null && img2img_gallery != null){
img2img_preview.style.width = img2img_gallery.clientWidth + "px"
img2img_preview.style.height = img2img_gallery.clientHeight + "px"
}
window.setTimeout(requestProgress, 500) window.setTimeout(requestProgress, 500)
}); });
mutationObserver.observe( progressbar, { childList:true, subtree:true }) mutationObserver.observe( progressbar, { childList:true, subtree:true })

View File

@ -31,6 +31,20 @@ button{
max-width: 10em; max-width: 10em;
} }
#txt2img_preview, #img2img_preview{
position: absolute;
width: 320px;
left: 0;
right: 0;
margin-left: auto;
margin-right: auto;
z-index: 100;
}
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
display: none;
}
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
position: absolute; position: absolute;
top: -0.6em; top: -0.6em;
@ -96,3 +110,4 @@ input[type="range"]{
text-align: right; text-align: right;
border-radius: 8px; border-radius: 8px;
} }

View File

@ -125,7 +125,8 @@ def wrap_gradio_gpu_call(func):
shared.state.sampling_step = 0 shared.state.sampling_step = 0
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.job_no = 0 shared.state.job_no = 0
shared.state.current_latent = None
shared.state.current_image = None
with queue_lock: with queue_lock:
res = func(*args, **kwargs) res = func(*args, **kwargs)
@ -163,7 +164,7 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
if __name__ == "__main__": if __name__ == "__main__":
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with singal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
os._exit(0) os._exit(0)