textual inversion embeddings support

settings tab
This commit is contained in:
AUTOMATIC 2022-08-25 21:52:05 +03:00
parent ec8a252260
commit 91dc8710ec
3 changed files with 356 additions and 107 deletions

View File

@ -1,7 +1,7 @@
# Stable Diffusion web UI # Stable Diffusion web UI
A browser interface based on Gradio library for Stable Diffusion. A browser interface based on Gradio library for Stable Diffusion.
Original script with Gradio UI was written by a kind anonymopus user. This is a modification. Original script with Gradio UI was written by a kind anonymous user. This is a modification.
![](screenshot.png) ![](screenshot.png)
## Installing and running ## Installing and running
@ -128,7 +128,7 @@ Example:
Gradio's loading graphic has a very negative effect on the processing speed of the neural network. Gradio's loading graphic has a very negative effect on the processing speed of the neural network.
My RTX 3090 makes images about 10% faster when the tab with gradio is not active. By default, the UI My RTX 3090 makes images about 10% faster when the tab with gradio is not active. By default, the UI
now hides loading progress animation and replaces it with static "Loading..." text, which achieves now hides loading progress animation and replaces it with static "Loading..." text, which achieves
the same effect. Use the --no-progressbar-hiding commandline option to revert this and show loading animations. the same effect. Use the `--no-progressbar-hiding` commandline option to revert this and show loading animations.
### Prompt validation ### Prompt validation
Stable Diffusion has a limit for input text length. If your prompt is too long, you will get a Stable Diffusion has a limit for input text length. If your prompt is too long, you will get a
@ -152,6 +152,28 @@ Adds information about generation parameters to PNG as a text chunk. You
can view this information later using any software that supports viewing can view this information later using any software that supports viewing
PNG chunk info, for example: https://www.nayuki.io/page/png-file-chunk-inspector PNG chunk info, for example: https://www.nayuki.io/page/png-file-chunk-inspector
This can be disabled using the `--disable-pnginfo` command line option.
![](images/pnginfo.png) ![](images/pnginfo.png)
### Textual Inversion
Allows you to use pretrained textual inversion embeddings.
See originial site for details: https://textual-inversion.github.io/.
I used lstein's repo for training embdedding: https://github.com/lstein/stable-diffusion; if
you want to train your own, I recommend following the guide on his site.
No additional libraries/repositories are required to use pretrained embeddings.
To make use of pretrained embeddings, create `embeddings` directory in the root dir of Stable
Diffusion and put your embeddings into it. They must be .pt files about 5Kb in size, each with only
one trained embedding, and the filename (without .pt) will be the term you'd use in prompt
to get that embedding.
As an example, I trained one for about 5000 steps: https://files.catbox.moe/e2ui6r.pt; it does
not produce very good results, but it does work. Download and rename it to `Usada Pekora.pt`,
and put it into `embeddings` dir and use Usada Pekora in prompt.
![](images/inversion.png)
### Settings
A tab with settings, allowing you to use UI to edit more than half of parameters that previously
were commandline. Settings are saved to config.js file. Settings that remain as commandline
options are ones that are required at startup.

BIN
images/inversion.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 678 KiB

397
webui.py
View File

@ -8,17 +8,19 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from itertools import islice from itertools import islice
from einops import rearrange, repeat from einops import rearrange, repeat
from torch import autocast from torch import autocast
from contextlib import contextmanager, nullcontext
import mimetypes import mimetypes
import random import random
import math import math
import html import html
import time import time
import json
import traceback
import k_diffusion as K import k_diffusion as K
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
import ldm.modules.encoders.modules
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@ -38,30 +40,18 @@ opt_f = 8
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
invalid_filename_chars = '<>:"/\|?*\n' invalid_filename_chars = '<>:"/\|?*\n'
config_filename = "config.json"
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--outdir", type=str, nargs="?", help="dir to write results to", default=None)
parser.add_argument("--skip_grid", action='store_true', help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",)
parser.add_argument("--skip_save", action='store_true', help="do not save indiviual samples. For speed measurements.",)
parser.add_argument("--n_rows", type=int, default=-1, help="rows in the grid; use -1 for autodetect and 0 for n_rows to be same as batch_size (default: -1)",)
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",) parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",) parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) # i disagree with where you're putting it but since all guidefags are doing it this way, there you go
parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--save-format", type=str, default='png', help="file format for saved indiviual samples; can be png or jpg") parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--grid-format", type=str, default='png', help="file format for saved grids; can be png or jpg")
parser.add_argument("--grid-extended-filename", action='store_true', help="save grid images to filenames with extended info: seed, prompt")
parser.add_argument("--jpeg-quality", type=int, default=80, help="quality for saved jpeg images")
parser.add_argument("--disable-pnginfo", action='store_true', help="disable saving text information about generation parameters as chunks to png files")
parser.add_argument("--inversion", action='store_true', help="switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo") cmd_opts = parser.parse_args()
opt = parser.parse_args()
GFPGAN_dir = opt.gfpgan_dir
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
@ -70,6 +60,49 @@ css_hide_progressbar = """
.meta-text { display:none!important; } .meta-text { display:none!important; }
""" """
class Options:
data = None
data_labels = {
"outdir": ("", "Output dictectory; if empty, defaults to 'outputs/*'"),
"samples_save": (True, "Save indiviual samples"),
"samples_format": ('png', 'File format for indiviual samples'),
"grid_save": (True, "Save image grids"),
"grid_format": ('png', 'File format for grids'),
"grid_extended_filename": (False, "Add extended info (seed, prompt) to filename when saving grid"),
"n_rows": (-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", -1, 16),
"jpeg_quality": (80, "Quality for saved jpeg images", 1, 100),
"verify_input": (True, "Check input, and produce warning if it's too long"),
"enable_pnginfo": (True, "Save text information about generation parameters as chunks to png files"),
}
def __init__(self):
self.data = {k: v[0] for k, v in self.data_labels.items()}
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data:
self.data[key] = value
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if self.data is not None:
if item in self.data:
return self.data[item]
return super(Options, self).__getattribute__(item)
def save(self, filename):
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file)
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
def chunk(it, size): def chunk(it, size):
it = iter(it) it = iter(it)
return iter(lambda: tuple(islice(it, size)), ()) return iter(lambda: tuple(islice(it, size)), ())
@ -154,13 +187,13 @@ def save_image(image, path, basename, seed, prompt, extension, info=None, short_
else: else:
filename = f"{basename}-{seed}-{prompt[:128]}.{extension}" filename = f"{basename}-{seed}-{prompt[:128]}.{extension}"
if extension == 'png' and not opt.disable_pnginfo: if extension == 'png' and opts.enable_pnginfo:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text("parameters", info) pnginfo.add_text("parameters", info)
else: else:
pnginfo = None pnginfo = None
image.save(os.path.join(path, filename), quality=opt.jpeg_quality, pnginfo=pnginfo) image.save(os.path.join(path, filename), quality=opts.jpeg_quality, pnginfo=pnginfo)
def plaintext_to_html(text): def plaintext_to_html(text):
@ -170,39 +203,22 @@ def plaintext_to_html(text):
def load_GFPGAN(): def load_GFPGAN():
model_name = 'GFPGANv1.3' model_name = 'GFPGANv1.3'
model_path = os.path.join(GFPGAN_dir, 'experiments/pretrained_models', model_name + '.pth') model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise Exception("GFPGAN model not found at path "+model_path) raise Exception("GFPGAN model not found at path "+model_path)
sys.path.append(os.path.abspath(GFPGAN_dir)) sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
from gfpgan import GFPGANer from gfpgan import GFPGANer
return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
GFPGAN = None
if os.path.exists(GFPGAN_dir):
try:
GFPGAN = load_GFPGAN()
print("Loaded GFPGAN")
except Exception:
import traceback
print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
config = OmegaConf.load(opt.config)
model = load_model_from_config(config, opt.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if opt.no_half else model.half()).to(device)
def image_grid(imgs, batch_size, round_down=False, force_n_rows=None): def image_grid(imgs, batch_size, round_down=False, force_n_rows=None):
if force_n_rows is not None: if force_n_rows is not None:
rows = force_n_rows rows = force_n_rows
elif opt.n_rows > 0: elif opts.n_rows > 0:
rows = opt.n_rows rows = opts.n_rows
elif opt.n_rows == 0: elif opts.n_rows == 0:
rows = batch_size rows = batch_size
else: else:
rows = math.sqrt(len(imgs)) rows = math.sqrt(len(imgs))
@ -353,6 +369,163 @@ def wrap_gradio_call(func):
return f return f
GFPGAN = None
if os.path.exists(cmd_opts.gfpgan_dir):
try:
GFPGAN = load_GFPGAN()
print("Loaded GFPGAN")
except Exception:
print("Error loading GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
class TextInversionEmbeddings:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = []
used_custom_terms = []
dir_mtime = None
def load(self, dir, model):
mt = os.path.getmtime(dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path)
param_dict = data['string_to_param']
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1].reshape(768)
self.word_embeddings[name] = emb
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dir):
try:
process_file(os.path.join(dir, fn), fn)
except:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, embeddings):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
def forward(self, text):
self.embeddings.fixes = []
self.embeddings.used_custom_terms = []
remade_batch_tokens = []
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length - 2
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
i = 0
while i < len(tokens):
token = tokens[i]
possible_matches = self.embeddings.ids_lookup.get(token, None)
if possible_matches is None:
remade_tokens.append(token)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
fixes.append((len(remade_tokens), word))
remade_tokens.append(777)
i += len(ids) - 1
found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
i += 1
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes)
remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes)
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
class EmbeddingsWithFixes(nn.Module):
def __init__(self, wrapped, embeddings):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
self.embeddings.fixes = []
inputs_embeds = self.wrapped(input_ids)
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
tensor[offset] = self.embeddings.word_embeddings[word]
return inputs_embeds
def get_learned_conditioning_with_embeddings(model, prompts):
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
return model.get_learned_conditioning(prompts)
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False): def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@ -392,7 +565,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.") print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
else: else:
if not opt.no_verify_input: if opts.verify_input:
try: try:
check_prompt_length(prompt, comments) check_prompt_length(prompt, comments)
except: except:
@ -403,27 +576,29 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
all_prompts = batch_size * n_iter * [prompt] all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))] all_seeds = [seed + x for x in range(len(all_prompts))]
info = f""" def infotext():
return f"""
{prompt} {prompt}
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''} Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip() + "".join(["\n\n" + x for x in comments]) """.strip() + "".join(["\n\n" + x for x in comments])
precision_scope = autocast if opt.precision == "autocast" else nullcontext if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
output_images = [] output_images = []
with torch.no_grad(), precision_scope("cuda"), model.ema_scope(): with torch.no_grad(), autocast("cuda"), model.ema_scope():
init_data = func_init() init_data = func_init()
for n in range(n_iter): for n in range(n_iter):
prompts = all_prompts[n * batch_size:(n + 1) * batch_size] prompts = all_prompts[n * batch_size:(n + 1) * batch_size]
seeds = all_seeds[n * batch_size:(n + 1) * batch_size] seeds = all_seeds[n * batch_size:(n + 1) * batch_size]
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(len(prompts) * [""]) uc = model.get_learned_conditioning(len(prompts) * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts) c = model.get_learned_conditioning(prompts)
if len(text_inversion_embeddings.used_custom_terms) > 0:
comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
# we manually generate all input noises because each one should have a specific seed # we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds) x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
@ -432,7 +607,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if prompt_matrix or not opt.skip_save or not opt.skip_grid: if prompt_matrix or opts.samples_save or opts.grid_save:
for i, x_sample in enumerate(x_samples_ddim): for i, x_sample in enumerate(x_samples_ddim):
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -442,12 +617,12 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
x_sample = restored_img x_sample = restored_img
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opt.save_format, info=info) save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
output_images.append(image) output_images.append(image)
base_count += 1 base_count += 1
if (prompt_matrix or not opt.skip_grid) and not do_not_save_grid: if (prompt_matrix or opts.grid_save) and not do_not_save_grid:
grid = image_grid(output_images, batch_size, round_down=prompt_matrix) grid = image_grid(output_images, batch_size, round_down=prompt_matrix)
if prompt_matrix: if prompt_matrix:
@ -461,23 +636,17 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
output_images.insert(0, grid) output_images.insert(0, grid)
save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename) save_image(grid, outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
grid_count += 1 grid_count += 1
torch_gc() torch_gc()
return output_images, seed, info return output_images, seed, infotext()
def load_embeddings(fp): def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int):
if fp is not None and hasattr(model, "embedding_manager"): outpath = opts.outdir or "outputs/txt2img-samples"
# load the file
model.embedding_manager.load(fp.name)
def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, embeddings_fp):
outpath = opt.outdir or "outputs/txt2img-samples"
load_embeddings(embeddings_fp)
if sampler_name == 'PLMS': if sampler_name == 'PLMS':
sampler = PLMSSampler(model) sampler = PLMSSampler(model)
@ -567,29 +736,25 @@ txt2img_interface = gr.Interface(
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Number(label='Seed', value=-1), gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion)
], ],
outputs=[ outputs=[
gr.Gallery(label="Images"), gr.Gallery(label="Images"),
gr.Number(label='Seed'), gr.Number(label='Seed'),
gr.HTML(), gr.HTML(),
], ],
title="Stable Diffusion Text-to-Image K", title="Stable Diffusion Text-to-Image",
description="Generate images from text with Stable Diffusion (using K-LMS)",
flagging_callback=Flagging() flagging_callback=Flagging()
) )
def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, embeddings_fp): def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opt.outdir or "outputs/img2img-samples" outpath = opts.outdir or "outputs/img2img-samples"
load_embeddings(embeddings_fp)
sampler = KDiffusionSampler(model) sampler = KDiffusionSampler(model)
@ -658,7 +823,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_mat
grid_count = len(os.listdir(outpath)) - 1 grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, force_n_rows=1) grid = image_grid(history, batch_size, force_n_rows=1)
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opt.grid_format, info=info, short_filename=not opt.grid_extended_filename) save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
output_images = history output_images = history
seed = initial_seed seed = initial_seed
@ -698,15 +863,14 @@ img2img_interface = gr.Interface(
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None), gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=GFPGAN is not None),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False), gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False), gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
gr.Slider(minimum=1, maximum=opt.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1), gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1), gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0), gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75), gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
gr.Number(label='Seed', value=-1), gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"), gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion)
], ],
outputs=[ outputs=[
gr.Gallery(), gr.Gallery(),
@ -714,15 +878,9 @@ img2img_interface = gr.Interface(
gr.HTML(), gr.HTML(),
], ],
title="Stable Diffusion Image-to-Image", title="Stable Diffusion Image-to-Image",
description="Generate images from images with Stable Diffusion",
allow_flagging="never", allow_flagging="never",
) )
interfaces = [
(txt2img_interface, "txt2img"),
(img2img_interface, "img2img")
]
def run_GFPGAN(image, strength): def run_GFPGAN(image, strength):
image = image.convert("RGB") image = image.convert("RGB")
@ -735,8 +893,7 @@ def run_GFPGAN(image, strength):
return res, 0, '' return res, 0, ''
if GFPGAN is not None: gfpgan_interface = gr.Interface(
interfaces.append((gr.Interface(
run_GFPGAN, run_GFPGAN,
inputs=[ inputs=[
gr.Image(label="Source", source="upload", interactive=True, type="pil"), gr.Image(label="Source", source="upload", interactive=True, type="pil"),
@ -750,13 +907,83 @@ if GFPGAN is not None:
title="GFPGAN", title="GFPGAN",
description="Fix faces on images", description="Fix faces on images",
allow_flagging="never", allow_flagging="never",
), "GFPGAN")) )
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
def run_settings(*args):
up = []
for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components):
opts.data[key] = value
up.append(comp.update(value=value))
opts.save(config_filename)
return 'Settings saved.', ''
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key][0]
labelinfo = opts.data_labels[key]
t = type(labelinfo[0])
label = labelinfo[1]
if t == str:
item = gr.Textbox(label=label, value=fun, lines=1)
elif t == int:
if len(labelinfo) == 4:
item = gr.Slider(minimum=labelinfo[2], maximum=labelinfo[3], step=1, label=label, value=fun)
else:
item = gr.Number(label=label, value=fun)
elif t == bool:
item = gr.Checkbox(label=label, value=fun)
else:
raise Exception(f'bad options item type: {str(t)} for key {key}')
return item
settings_interface = gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never",
)
interfaces = [
(txt2img_interface, "txt2img"),
(img2img_interface, "img2img"),
(gfpgan_interface, "GFPGAN"),
(settings_interface, "Settings"),
]
config = OmegaConf.load(cmd_opts.config)
model = load_model_from_config(config, cmd_opts.ckpt)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)
text_inversion_embeddings = TextInversionEmbeddings()
if os.path.exists(cmd_opts.embeddings_dir):
text_inversion_embeddings.hijack(model)
if GFPGAN is None:
interfaces = [x for x in interfaces if x[0] != gfpgan_interface]
demo = gr.TabbedInterface( demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces], interface_list=[x[0] for x in interfaces],
tab_names=[x[1] for x in interfaces], tab_names=[x[1] for x in interfaces],
css=("" if opt.no_progressbar_hiding else css_hide_progressbar) + """ css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; } .performance { font-size: 0.85em; color: #444; }
""" """