save parameters for images when using the Save button.

This commit is contained in:
AUTOMATIC 2022-09-28 17:05:23 +03:00
parent 5eb9d1aeac
commit aea5b2510e
3 changed files with 18 additions and 9 deletions

View File

@ -100,7 +100,7 @@ class StableDiffusionProcessing:
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
self.images = images_list self.images = images_list
self.prompt = p.prompt self.prompt = p.prompt
self.negative_prompt = p.negative_prompt self.negative_prompt = p.negative_prompt
@ -139,6 +139,7 @@ class Processed:
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_seeds = all_seeds or [self.seed]
self.all_subseeds = all_subseeds or [self.subseed] self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self): def js(self):
obj = { obj = {
@ -165,6 +166,7 @@ class Processed:
"denoising_strength": self.denoising_strength, "denoising_strength": self.denoising_strength,
"extra_generation_params": self.extra_generation_params, "extra_generation_params": self.extra_generation_params,
"index_of_first_image": self.index_of_first_image, "index_of_first_image": self.index_of_first_image,
"infotexts": self.infotexts,
} }
return json.dumps(obj) return json.dumps(obj)
@ -322,6 +324,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if os.path.exists(cmd_opts.embeddings_dir): if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
infotexts = []
output_images = [] output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
@ -404,6 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if opts.samples_save and not p.do_not_save_samples: if opts.samples_save and not p.do_not_save_samples:
images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
infotexts.append(infotext(n, i))
output_images.append(image) output_images.append(image)
state.nextjob() state.nextjob()
@ -416,6 +420,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size) grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid: if opts.return_grid:
infotexts.insert(0, infotext())
output_images.insert(0, grid) output_images.insert(0, grid)
index_of_first_image = 1 index_of_first_image = 1
@ -423,7 +428,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image) return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

View File

@ -143,6 +143,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
})) }))
options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-paths', "Paths for saving"), {
@ -180,7 +181,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
"save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {

View File

@ -12,7 +12,7 @@ import traceback
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image, PngImagePlugin
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
@ -97,10 +97,11 @@ def save_files(js_data, images, index):
filenames = [] filenames = []
data = json.loads(js_data) data = json.loads(js_data)
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
images = [images[index]] images = [images[index]]
data["seed"] += (index - 1 if opts.return_grid else index) infotexts = [data["infotexts"][index]]
else:
infotexts = data["infotexts"]
with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
at_start = file.tell() == 0 at_start = file.tell() == 0
@ -116,8 +117,11 @@ def save_files(js_data, images, index):
if filedata.startswith("data:image/png;base64,"): if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
with open(filepath, "wb") as imgfile: pnginfo = PngImagePlugin.PngInfo()
imgfile.write(base64.decodebytes(filedata.encode('utf-8'))) pnginfo.add_text('parameters', infotexts[i])
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
filenames.append(filename) filenames.append(filename)