Merge pull request #5976 from AbstractQbit/fast_preview

Add an option for faster low quality previews
This commit is contained in:
AUTOMATIC1111 2022-12-24 18:38:42 +03:00 committed by GitHub
commit a6a54a7529
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 9 deletions

View File

@ -106,20 +106,29 @@ def setup_img2img_steps(p, steps=None):
return steps, t_enc return steps, t_enc
def single_sample_to_image(sample): def single_sample_to_image(sample, approximation=False):
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] if approximation:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2
coefs = torch.tensor(
[[ 0.298, 0.207, 0.208],
[ 0.187, 0.286, 0.173],
[-0.158, 0.189, 0.264],
[-0.184, -0.271, -0.473]]).to(sample.device)
x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
else:
x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
return Image.fromarray(x_sample) return Image.fromarray(x_sample)
def sample_to_image(samples, index=0): def sample_to_image(samples, index=0, approximation=False):
return single_sample_to_image(samples[index]) return single_sample_to_image(samples[index], approximation)
def samples_to_image_grid(samples): def samples_to_image_grid(samples, approximation=False):
return images.image_grid([single_sample_to_image(sample) for sample in samples]) return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples])
def store_latent(decoded): def store_latent(decoded):
@ -127,7 +136,7 @@ def store_latent(decoded):
if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
if not shared.parallel_processing_allowed: if not shared.parallel_processing_allowed:
shared.state.current_image = sample_to_image(decoded) shared.state.current_image = sample_to_image(decoded, approximation=opts.show_progress_approximate)
class InterruptedException(BaseException): class InterruptedException(BaseException):

View File

@ -212,9 +212,9 @@ class State:
import modules.sd_samplers import modules.sd_samplers
if opts.show_progress_grid: if opts.show_progress_grid:
self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent, approximation=opts.show_progress_approximate)
else: else:
self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) self.current_image = modules.sd_samplers.sample_to_image(self.current_latent, approximation=opts.show_progress_approximate)
self.current_image_sampling_step = self.sampling_step self.current_image_sampling_step = self.sampling_step
@ -392,6 +392,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {
"show_progressbar": OptionInfo(True, "Show progressbar"), "show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
"show_progress_approximate": OptionInfo(False, "Calculate small previews using fast linear approximation instead of VAE"),
"show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),