diff --git a/webui.py b/webui.py index 51ce02c7..55654f55 100644 --- a/webui.py +++ b/webui.py @@ -58,6 +58,7 @@ parser.add_argument("--grid-extended-filename", action='store_true', help="save parser.add_argument("--jpeg-quality", type=int, default=80, help="quality for saved jpeg images") parser.add_argument("--disable-pnginfo", action='store_true', help="disable saving text information about generation parameters as chunks to png files") +parser.add_argument("--inversion", action='store_true', help="switch to stable inversion version; allows for uploading embeddings; this option should be used only with textual inversion repo") opt = parser.parse_args() GFPGAN_dir = opt.gfpgan_dir @@ -189,8 +190,8 @@ if os.path.exists(GFPGAN_dir): print("Error loading GFPGAN:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) -config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml") -model = load_model_from_config(config, "models/ldm/stable-diffusion-v1/model.ckpt") +config = OmegaConf.load(opt.config) +model = load_model_from_config(config, opt.ckpt) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = (model if opt.no_half else model.half()).to(device) @@ -467,9 +468,17 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, return output_images, seed, info -def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int): +def load_embeddings(fp): + if fp is not None and hasattr(model, "embedding_manager"): + # load the file + model.embedding_manager.load(fp.name) + + +def txt2img(prompt: str, ddim_steps: int, sampler_name: str, use_GFPGAN: bool, prompt_matrix: bool, ddim_eta: float, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, embeddings_fp): outpath = opt.outdir or "outputs/txt2img-samples" + load_embeddings(embeddings_fp) + if sampler_name == 'PLMS': sampler = PLMSSampler(model) elif sampler_name == 'DDIM': @@ -564,6 +573,7 @@ txt2img_interface = gr.Interface( gr.Number(label='Seed', value=-1), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512), + gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion) ], outputs=[ gr.Gallery(label="Images"), @@ -576,9 +586,11 @@ txt2img_interface = gr.Interface( ) -def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int): +def img2img(prompt: str, init_img, ddim_steps: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, embeddings_fp): outpath = opt.outdir or "outputs/img2img-samples" + load_embeddings(embeddings_fp) + sampler = KDiffusionSampler(model) assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' @@ -693,7 +705,8 @@ img2img_interface = gr.Interface( gr.Number(label='Seed', value=-1), gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512), gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512), - gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize") + gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize"), + gr.File(label = "Embeddings file for textual inversion", visible=opt.inversion) ], outputs=[ gr.Gallery(), @@ -739,6 +752,7 @@ if GFPGAN is not None: allow_flagging="never", ), "GFPGAN")) + demo = gr.TabbedInterface( interface_list=[x[0] for x in interfaces], tab_names=[x[1] for x in interfaces], @@ -748,4 +762,4 @@ demo = gr.TabbedInterface( """ ) -demo.launch() +demo.launch() \ No newline at end of file