diff --git a/.gitignore b/.gitignore index 554c965e..3532dab3 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,8 @@ __pycache__ /tmp /model.ckpt /models/**/* +/GFPGANv1.3.pth +/gfpgan/weights/*.pth /ui-config.json /outputs /config.json diff --git a/javascript/hints.js b/javascript/hints.js index 59dd770c..96cd24d5 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -15,6 +15,7 @@ titles = { "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.", + "\uD83D\uDCC2": "Open images output directory", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index 7db4db48..562d2552 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -182,4 +182,23 @@ onUiUpdate(function(){ }); json_elem.parentElement.style.display="none" + + if (!txt2img_textarea) { + txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); + txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); + } + if (!img2img_textarea) { + img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); + img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); + } }) + +let txt2img_textarea, img2img_textarea = undefined; +let wait_time = 800 +let token_timeout; + +function update_token_counter(button_id) { + if (token_timeout) + clearTimeout(token_timeout); + token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); +} diff --git a/launch.py b/launch.py index 77edea26..3b8d8f23 100644 --- a/launch.py +++ b/launch.py @@ -15,11 +15,11 @@ torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") -k_diffusion_package = os.environ.get('K_DIFFUSION_PACKAGE', "git+https://github.com/crowsonkb/k-diffusion.git@1a0703dfb7d24d8806267c3e7ccc4caf67fd1331") gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") +k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -107,10 +107,7 @@ if not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") if not skip_torch_cuda_test: - run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDINE_ARGS variable to disable this check'") - -if not is_installed("k_diffusion.sampling"): - run_pip(f"install 
{k_diffusion_package}", "k-diffusion") + run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") if not is_installed("gfpgan"): run_pip(f"install {gfpgan_package}", "gfpgan") @@ -119,6 +116,7 @@ os.makedirs(dir_repos, exist_ok=True) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) +git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) if os.path.isdir(repo_dir('latent-diffusion')): @@ -133,6 +131,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI") sys.argv += args +if "--exit" in args: + print("Exiting because of --exit argument") + exit(0) def start_webui(): print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") diff --git a/modules/extras.py b/modules/extras.py index d7d0fa54..1d4e9fa8 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -6,13 +6,14 @@ from PIL import Image import torch import tqdm -from modules import processing, shared, images, devices +from modules import processing, shared, images, devices, sd_models from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html import modules.codeformer_model import piexif import piexif.helper +import gradio as gr cached_images = {} @@ -141,7 +142,7 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount): +def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name): # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -151,45 +152,52 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount): alpha = alpha * alpha * (3 - (2 * alpha)) return theta0 + ((theta1 - theta0) * alpha) - if os.path.exists(modelname_0): - model0_filename = modelname_0 - modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0] - else: - model0_filename = 'models/' + modelname_0 + '.ckpt' + # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) + def inv_sigmoid(theta0, theta1, alpha): + import math + alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) + return theta0 + ((theta1 - theta0) * alpha) - if os.path.exists(modelname_1): - model1_filename = modelname_1 - modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0] - else: - model1_filename = 'models/' + modelname_1 + '.ckpt' + primary_model_info = sd_models.checkpoints_list[primary_model_name] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - print(f"Loading {model0_filename}...") - model_0 = torch.load(model0_filename, map_location='cpu') + print(f"Loading {primary_model_info.filename}...") + primary_model = torch.load(primary_model_info.filename, map_location='cpu') - print(f"Loading {model1_filename}...") - model_1 = torch.load(model1_filename, map_location='cpu') - - theta_0 
= model_0['state_dict'] - theta_1 = model_1['state_dict'] + print(f"Loading {secondary_model_info.filename}...") + secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') + + theta_0 = primary_model['state_dict'] + theta_1 = secondary_model['state_dict'] theta_funcs = { "Weighted Sum": weighted_sum, "Sigmoid": sigmoid, + "Inverse Sigmoid": inv_sigmoid, } theta_func = theta_funcs[interp_method] print(f"Merging...") for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount) + theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + if save_as_half: + theta_0[key] = theta_0[key].half() for key in theta_1.keys(): if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] + if save_as_half: + theta_0[key] = theta_0[key].half() + + filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + filename = filename if custom_name == '' else (custom_name + '.ckpt') + output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename) - output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt' print(f"Saving to {output_modelname}...") - torch.save(model_0, output_modelname) + torch.save(primary_model, output_modelname) + + sd_models.list_models() print(f"Checkpoint saved.") - return "Checkpoint saved to " + output_modelname + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] diff --git a/modules/img2img.py b/modules/img2img.py index d80b3e75..03e934e9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -124,4 +124,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if opts.samples_log_stdout: print(generation_info_js) - return processed.images, generation_info_js, plaintext_to_html(processed.info) \ No newline at end of file + return processed.images, generation_info_js, plaintext_to_html(processed.info) diff --git a/modules/paths.py b/modules/paths.py index 015fa672..0f9f6a56 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -21,6 +21,7 @@ path_dirs = [ (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'), (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'), + (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'), ] paths = {} diff --git a/modules/processing.py b/modules/processing.py index 8d043f4d..4ecdfcd2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -49,7 +49,7 @@ def apply_color_correction(correction, image): class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", 
styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -75,15 +75,15 @@ class StableDiffusionProcessing: self.do_not_save_grid: bool = do_not_save_grid self.extra_generation_params: dict = extra_generation_params or {} self.overlay_images = overlay_images + self.eta = eta self.paste_to = None self.color_corrections = None self.denoising_strength: float = 0 - - self.ddim_eta = opts.ddim_eta + self.ddim_discretize = opts.ddim_discretize self.s_churn = opts.s_churn self.s_tmin = opts.s_tmin - self.s_tmax = float('inf') # not representable as a standard ui option + self.s_tmax = float('inf') # not representable as a standard ui option self.s_noise = opts.s_noise if not seed_enable_extras: @@ -100,7 +100,7 @@ class StableDiffusionProcessing: class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -124,7 +124,7 @@ class Processed: self.extra_generation_params = p.extra_generation_params self.index_of_first_image = index_of_first_image - self.ddim_eta = p.ddim_eta + self.eta = p.eta self.ddim_discretize = p.ddim_discretize self.s_churn = p.s_churn self.s_tmin = p.s_tmin @@ -139,6 +139,7 @@ class Processed: self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] self.all_subseeds = all_subseeds or [self.subseed] + self.infotexts = infotexts or [info] def js(self): obj = { @@ -165,6 +166,7 @@ class Processed: "denoising_strength": self.denoising_strength, "extra_generation_params": self.extra_generation_params, "index_of_first_image": self.index_of_first_image, + "infotexts": self.infotexts, } return json.dumps(obj) @@ -269,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), + "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), } generation_params.update(p.extra_generation_params) @@ -277,7 +280,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" - return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments]) + return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() def process_images(p: StableDiffusionProcessing) -> Processed: @@ -322,6 +325,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if os.path.exists(cmd_opts.embeddings_dir): 
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) + infotexts = [] output_images = [] precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope) @@ -404,6 +408,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if opts.samples_save and not p.do_not_save_samples: images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) + infotexts.append(infotext(n, i)) output_images.append(image) state.nextjob() @@ -416,6 +421,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: grid = images.image_grid(output_images, p.batch_size) if opts.return_grid: + infotexts.insert(0, infotext()) output_images.insert(0, grid) index_of_first_image = 1 @@ -423,7 +429,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext(), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image) + return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index a6a25b28..e811eb9e 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step): return res +re_attention = re.compile(r""" +\\\(| +\\\)| +\\\[| +\\]| +\\\\| +\\| +\(| +\[| +:([+-]?[.\d]+)\)| +\)| +]| +[^\\()\[\]:]+| +: +""", re.X) -#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100) + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + + Example: + + 'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).' 
+ + produces: + + [ + ['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1] + ] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith('\\'): + res.append([text[1:], 1.0]) + elif text == '(': + round_brackets.append(len(res)) + elif text == '[': + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ')' and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == ']' and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + return res diff --git a/modules/scripts.py b/modules/scripts.py index 202374e6..7c3bd5e7 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -55,7 +55,7 @@ def load_scripts(basedir): if not os.path.exists(basedir): return - for filename in os.listdir(basedir): + for filename in sorted(os.listdir(basedir)): path = os.path.join(basedir, filename) if not os.path.isfile(path): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7b2030d4..5945b7c2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,6 +6,7 @@ import torch import numpy as np from torch import einsum +from modules import prompt_parser from modules.shared import opts, device, cmd_opts from ldm.util import default @@ -180,6 +181,7 @@ class StableDiffusionModelHijack: dir_mtime = None layers = None circular_enabled = False + clip = None def load_textual_inversion_embeddings(self, dirname, model): mt = os.path.getmtime(dirname) @@ -210,6 +212,7 @@ class StableDiffusionModelHijack: param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] + # diffuser concepts elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: assert len(data.keys()) == 1, 'embedding file has multiple terms in it' @@ -235,7 +238,7 @@ class StableDiffusionModelHijack: print(traceback.format_exc(), file=sys.stderr) continue - print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.") + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -243,6 +246,8 @@ class StableDiffusionModelHijack: model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + self.clip = m.cond_stage_model + if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): @@ -259,6 +264,14 @@ 
class StableDiffusionModelHijack: self.layers = flatten(m) + def undo_hijack(self, m): + if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords: + m.cond_stage_model = m.cond_stage_model.wrapped + + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: + model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + def apply_circular(self, enable): if self.circular_enabled == enable: return @@ -268,6 +281,11 @@ class StableDiffusionModelHijack: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' + def tokenize(self, text): + max_length = self.clip.max_length - 2 + _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) + return remade_batch_tokens[0], token_count, max_length + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): @@ -294,14 +312,101 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def forward(self, text): - self.hijack.fixes = [] - self.hijack.comments = [] + + def tokenize_line(self, line, used_custom_terms, hijack_comments): + id_start = self.wrapped.tokenizer.bos_token_id + id_end = self.wrapped.tokenizer.eos_token_id + maxlen = self.wrapped.max_length + + if opts.enable_emphasis: + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] + + fixes = [] + remade_tokens = [] + multipliers = [] + + for tokens, (text, weight) in zip(tokenized, parsed): + i = 0 + while i < len(tokens): + token = tokens[i] + + possible_matches = self.hijack.ids_lookup.get(token, None) + + if possible_matches is None: + remade_tokens.append(token) + multipliers.append(weight) + else: + found = False + for ids, word in possible_matches: + if tokens[i:i + len(ids)] == ids: + emb_len = int(self.hijack.word_embeddings[word].shape[0]) + fixes.append((len(remade_tokens), word)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + i += len(ids) - 1 + found = True + used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) + break + + if not found: + remade_tokens.append(token) + multipliers.append(weight) + i += 1 + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + return remade_tokens, fixes, multipliers, token_count + + def process_text(self, texts): + used_custom_terms = [] remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_multipliers = [] + for line in texts: + if line in cache: + remade_tokens, fixes, multipliers = cache[line] + else: + 
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + + cache[line] = (remade_tokens, fixes, multipliers) + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + + def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id maxlen = self.wrapped.max_length used_custom_terms = [] + remade_batch_tokens = [] + overflowing_words = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 cache = {} batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] @@ -353,9 +458,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - - self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) @@ -364,11 +468,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] remade_batch_tokens.append(remade_tokens) - self.hijack.fixes.append(fixes) + hijack_fixes.append(fixes) batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def forward(self, text): + + if opts.use_old_emphasis_implementation: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + else: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + + + self.hijack.fixes = hijack_fixes + self.hijack.comments = hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens) diff --git a/modules/sd_models.py b/modules/sd_models.py index 23826727..7ed22c1e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,8 +15,9 @@ model_dir = "Stable-diffusion" model_path = os.path.join(models_path, model_dir) model_name = "sd-v1-4.ckpt" model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" +user_dir = None -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} try: @@ -47,23 +48,56 @@ def setup_model(dirname): global model_path global model_name global model_url + global user_dir + global model_list + user_dir = dirname if not 
os.path.exists(model_path): os.makedirs(model_path) checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt") + list_models() + + +def checkpoint_tiles(): + return sorted([x.title for x in checkpoints_list.values()]) + + +def list_models(): + global model_path + global model_url + global model_name + global user_dir + checkpoints_list.clear() + model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir, + ext_filter=[".ckpt"], download_name=model_name) + print(f"Model list: {model_list}") + model_dir = os.path.abspath(model_path) + + def modeltitle(path, h): + abspath = os.path.abspath(path) + + if abspath.startswith(model_dir): + name = abspath.replace(model_dir, '') + else: + name = os.path.basename(path) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + + return f'{name} [{h}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) - title = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h) + title, model_name = modeltitle(cmd_ckpt, h) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in model_list: h = model_hash(filename) title = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h) + checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) def model_hash(filename): @@ -89,7 +123,7 @@ def select_checkpoint(): if len(checkpoints_list) == 0: print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) - print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) + print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. 
The program will exit.", file=sys.stderr) exit(1) @@ -142,7 +176,7 @@ def load_model(): def reload_model_weights(sd_model, info=None): - from modules import lowvram, devices + from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if sd_model.sd_model_checkpint == checkpoint_info.filename: @@ -153,8 +187,12 @@ def reload_model_weights(sd_model, info=None): else: sd_model.to(devices.cpu) + sd_hijack.model_hijack.undo_hijack(sd_model) + load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) + sd_hijack.model_hijack.hijack(sd_model) + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index cfc3ee40..5642b870 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -3,6 +3,7 @@ import numpy as np import torch import tqdm from PIL import Image +import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim @@ -22,6 +23,8 @@ samplers_k_diffusion = [ ('Heun', 'sample_heun', ['k_heun']), ('DPM2', 'sample_dpm_2', ['k_dpm_2']), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), + ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), ] samplers_data_k_diffusion = [ @@ -35,12 +38,12 @@ samplers = [ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), ] -samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] +samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']] sampler_extra_params = { - 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'], - 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'], - 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'], + 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } def setup_img2img_steps(p, steps=None): @@ -98,6 +101,8 @@ class VanillaStableDiffusionSampler: self.init_latent = None self.sampler_noises = None self.step = 0 + self.eta = None + self.default_eta = 0.0 def number_of_needed_noises(self, p): return 0 @@ -120,20 +125,29 @@ class VanillaStableDiffusionSampler: self.step += 1 return res + def initialize(self, p): + self.eta = p.eta or opts.eta_ddim + + for fieldname in ['p_sample_ddim', 'p_sample_plms']: + if hasattr(self.sampler, fieldname): + setattr(self.sampler, fieldname, self.p_sample_ddim_hook) + + self.mask = p.mask if hasattr(p, 'mask') else None + self.nmask = p.nmask if hasattr(p, 'nmask') else None + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): steps, t_enc = setup_img2img_steps(p, steps) + self.initialize(p) + # existing code fails with cetain step counts, like 9 try: - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False) + self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) except Exception: - self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False) + self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) x1 = 
self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - self.sampler.p_sample_ddim = self.p_sample_ddim_hook - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None self.init_latent = x self.step = 0 @@ -142,11 +156,8 @@ class VanillaStableDiffusionSampler: return samples def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - self.mask = None - self.nmask = None + self.initialize(p) + self.init_latent = None self.step = 0 @@ -154,9 +165,9 @@ class VanillaStableDiffusionSampler: # existing code fails with cetin step counts, like 9 try: - samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta) + samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) except Exception: - samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta) + samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) return samples_ddim @@ -229,11 +240,13 @@ class KDiffusionSampler: self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname,[]) + self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None self.sampler_noise_index = 0 self.stop_at = None + self.eta = None + self.default_eta = 1.0 def callback_state(self, d): store_latent(d["denoised"]) @@ -252,22 +265,12 @@ class KDiffusionSampler: self.sampler_noise_index += 1 return res - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): - steps, t_enc = setup_img2img_steps(p, steps) - - sigmas = self.model_wrap.get_sigmas(steps) - - noise = noise * sigmas[steps - t_enc - 1] - - xi = x + noise - - sigma_sched = sigmas[steps - t_enc - 1:] - + def initialize(self, p): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.init_latent = x self.model_wrap.step = 0 self.sampler_noise_index = 0 + self.eta = p.eta or opts.eta_ancestral if hasattr(k_diffusion.sampling, 'trange'): k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs) @@ -276,9 +279,28 @@ class KDiffusionSampler: k_diffusion.sampling.torch = TorchHijack(self) extra_params_kwargs = {} - for val in self.extra_params: - if hasattr(p,val): - extra_params_kwargs[val] = getattr(p,val) + for param_name in self.extra_params: + if 
hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + extra_params_kwargs['eta'] = self.eta + + return extra_params_kwargs + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + steps, t_enc = setup_img2img_steps(p, steps) + + sigmas = self.model_wrap.get_sigmas(steps) + + noise = noise * sigmas[steps - t_enc - 1] + xi = x + noise + + extra_params_kwargs = self.initialize(p) + + sigma_sched = sigmas[steps - t_enc - 1:] + + self.model_wrap_cfg.init_latent = x return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) @@ -288,21 +310,14 @@ class KDiffusionSampler: sigmas = self.model_wrap.get_sigmas(steps) x = x * sigmas[0] - self.model_wrap_cfg.step = 0 - self.sampler_noise_index = 0 - - if hasattr(k_diffusion.sampling, 'trange'): - k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs) - - if self.sampler_noises is not None: - k_diffusion.sampling.torch = TorchHijack(self) - - extra_params_kwargs = {} - for val in self.extra_params: - if hasattr(p,val): - extra_params_kwargs[val] = getattr(p,val) - - samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) - + extra_params_kwargs = self.initialize(p) + if 'sigma_min' in inspect.signature(self.func).parameters: + extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() + extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() + if 'n' in inspect.signature(self.func).parameters: + extra_params_kwargs['n'] = steps + else: + extra_params_kwargs['sigmas'] = sigmas + samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) return samples diff --git a/modules/shared.py b/modules/shared.py index 4c31039d..69002158 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -155,6 +155,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), + "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { @@ -182,7 +183,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), - "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 
1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), })) @@ -190,7 +190,6 @@ options_templates.update(options_section(('face-restoration', "Face restoration" "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), - "save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"), })) options_templates.update(options_section(('system', "System"), { @@ -200,12 +199,13 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), - "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. 
Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), @@ -231,8 +231,9 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "ddim_eta": OptionInfo(0.0, "DDIM eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform','quad']}), + "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/modules/ui.py b/modules/ui.py index e96109c9..ada9a38e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -9,10 +9,13 @@ import random import sys import time import traceback +import platform +import subprocess as sp import numpy as np import torch -from PIL import Image +from PIL import Image, PngImagePlugin +import piexif import gradio as gr import gradio.utils @@ -22,6 +25,7 @@ from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img +from modules.sd_hijack import model_hijack import modules.ldsr_model import modules.scripts import modules.gfpgan_model @@ -60,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ - +folder_symbol = '\uD83D\uDCC2' def plaintext_to_html(text): text = "
<p>" + "<br>
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "
</p>
" @@ -97,10 +101,11 @@ def save_files(js_data, images, index): filenames = [] data = json.loads(js_data) - - if index > -1 and opts.save_selected_only and (index > 0 or not opts.return_grid): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only images = [images[index]] - data["seed"] += (index - 1 if opts.return_grid else index) + infotexts = [data["infotexts"][index]] + else: + infotexts = data["infotexts"] with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: at_start = file.tell() == 0 @@ -109,15 +114,26 @@ def save_files(js_data, images, index): writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) filename_base = str(int(time.time() * 1000)) + extension = opts.samples_format.lower() for i, filedata in enumerate(images): - filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png" + filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}" filepath = os.path.join(opts.outdir_save, filename) if filedata.startswith("data:image/png;base64,"): filedata = filedata[len("data:image/png;base64,"):] - with open(filepath, "wb") as imgfile: - imgfile.write(base64.decodebytes(filedata.encode('utf-8'))) + image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8')))) + if opts.enable_pnginfo and extension == 'png': + pnginfo = PngImagePlugin.PngInfo() + pnginfo.add_text('parameters', infotexts[i]) + image.save(filepath, pnginfo=pnginfo) + else: + image.save(filepath, quality=opts.jpeg_quality) + + if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"): + piexif.insert(piexif.dump({"Exif": { + piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode") + }}), filepath) filenames.append(filename) @@ -329,6 +345,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) +def update_token_counter(text): + tokens, token_count, max_length = model_hijack.tokenize(text) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" @@ -338,11 +358,14 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2) + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter]) with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -449,6 +472,8 @@ def 
create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): send_to_img2img = gr.Button('Send to img2img') send_to_inpaint = gr.Button('Send to inpaint') send_to_extras = gr.Button('Send to extras') + button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' + open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id) with gr.Group(): html_info = gr.HTML() @@ -625,6 +650,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): img2img_send_to_img2img = gr.Button('Send to img2img') img2img_send_to_inpaint = gr.Button('Send to inpaint') img2img_send_to_extras = gr.Button('Send to extras') + button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' + open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id) with gr.Group(): html_info = gr.HTML() @@ -797,6 +824,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): html_info = gr.HTML() extras_send_to_img2img = gr.Button('Send to img2img') extras_send_to_inpaint = gr.Button('Send to inpaint') + button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else '' + open_extras_folder = gr.Button('Open output directory', elem_id=button_id) submit.click( fn=run_extras, @@ -857,30 +886,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): - gr.HTML(value="

<p>A merger of the two checkpoints will be generated in your /models directory.</p>
") + gr.HTML(value="
<p>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") - modelname_0 = gr.Textbox(elem_id="modelmerger_modelname_0", label="Model Name (to)") - modelname_1 = gr.Textbox(elem_id="modelmerger_modelname_1", label="Model Name (from)") - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid"], value="Weighted Sum", label="Interpolation Method") + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") + custom_name = gr.Textbox(label="Custom Name (Optional)") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) - submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") + save_as_half = gr.Checkbox(value=False, label="Save as float16") + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - submit.click( - fn=run_modelmerger, - inputs=[ - modelname_0, - modelname_1, - interp_method, - interp_amount - ], - outputs=[ - submit_result, - ] - ) - def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -904,6 +923,17 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): return comp(label=info.label, value=fun, **(args or {})) components = [] + component_dict = {} + + def open_folder(f): + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + else: + sp.Popen(["xdg-open", path]) def run_settings(*args): changed = 0 @@ -959,7 +989,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

<h1>{}</h1>

'.format(item.section[1])) - components.append(create_setting_component(k)) + component = create_setting_component(k) + component_dict[k] = component + components.append(component) items_displayed += 1 request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") @@ -1009,7 +1041,34 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): inputs=components, outputs=[result, text_settings], ) + + def modelmerger(*args): + try: + results = run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() #To remove the potentially missing models from the list + return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] + return results + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2'] txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names] img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names] @@ -1048,6 +1107,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[extras_image], ) + open_txt2img_folder.click( + fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples), + inputs=[], + outputs=[], + ) + + open_img2img_folder.click( + fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples), + inputs=[], + outputs=[], + ) + + open_extras_folder.click( + fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples), + inputs=[], + outputs=[], + ) + img2img_send_to_extras.click( fn=lambda x: image_from_url_text(x), _js="extract_image_from_gallery_extras", diff --git a/requirements.txt b/requirements.txt index 08935506..7cb9d329 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,9 +4,8 @@ fairscale==0.4.4 fonts font-roboto gfpgan -gradio +gradio==3.4b3 invisible-watermark -git+https://github.com/crowsonkb/k-diffusion.git numpy omegaconf piexif @@ -16,5 +15,12 @@ realesrgan scikit-image>=0.19 git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379 timm==0.4.12 -transformers +transformers==4.19.2 torch +einops +jsonmerge +clean-fid +git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 +resize-right +torchdiffeq +kornia diff --git a/requirements_versions.txt b/requirements_versions.txt index 505498e7..1e8006e0 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -14,4 +14,11 @@ fonts font-roboto timm==0.6.7 fairscale==0.4.9 -piexif==1.1.3 \ No newline at end of file +piexif==1.1.3 +einops==0.4.1 +jsonmerge==1.8.0 +clean-fid==0.1.29 +git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1 +resize-right==0.0.2 +torchdiffeq==0.2.3 +kornia==0.6.7 diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7c01231f..f8bc64c4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -87,12 +87,12 @@ axis_options = [ AxisOption("Prompt S/R", str, apply_prompt, format_value), 
AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), - AxisOption("DDIM Eta", float, apply_field("ddim_eta"), format_value_add_label), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),# as it is now all AxisOptionImg2Img items must go after AxisOption ones + AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), + AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), + AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), + AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), + AxisOption("Eta", float, apply_field("eta"), format_value_add_label), + AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -159,6 +159,9 @@ class Script(scripts.Script): p.batch_size = 1 def process_axis(opt, vals): + if opt.label == 'Nothing': + return [0] + valslist = [x.strip() for x in vals.split(",")] if opt.type == int: diff --git a/style.css b/style.css index 4054e2df..9709c4ee 100644 --- a/style.css +++ b/style.css @@ -1,5 +1,11 @@ .output-html p {margin: 0 0.5em;} +.row > *, +.row > .gr-form > * { + min-width: min(120px, 100%); + flex: 1 1 0%; +} + .performance { font-size: 0.85em; color: #444; @@ -43,13 +49,17 @@ margin-right: auto; } -#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{ +#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{ min-width: auto; flex-grow: 0; padding-left: 0.25em; padding-right: 0.25em; } +#hidden_element{ + display: none; +} + #seed_row, #subseed_row{ gap: 0.5rem; } @@ -389,3 +399,7 @@ input[type="range"]{ border-radius: 8px; display: none; } + +.red { + color: red; +} diff --git a/webui.py b/webui.py index be1bc769..5fd65edc 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,8 @@ import os +import threading + +from modules import devices +from modules.paths import script_path import signal import threading import modules.paths @@ -44,6 +48,8 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func): def f(*args, **kwargs): + devices.torch_gc() + shared.state.sampling_step = 0 shared.state.job_count = -1 shared.state.job_no = 0 @@ -59,6 +65,8 @@ def wrap_gradio_gpu_call(func): shared.state.job = "" shared.state.job_count = 0 + devices.torch_gc() + return res return modules.ui.wrap_gradio_call(f)