diff --git a/javascript/hints.js b/javascript/hints.js index fa5e5ae8..e746e20d 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -92,6 +92,7 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", + "No interpolation": "Result = A", "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", diff --git a/javascript/ui.js b/javascript/ui.js index 428375d4..37788a3e 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -176,8 +176,6 @@ function modelmerger(){ var id = randomId() requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) - gradioApp().getElementById('modelmerger_result').innerHTML = '' - var res = create_submit_args(arguments) res[0] = id return res diff --git a/modules/extras.py b/modules/extras.py index 034f28e4..fe701a0e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html @@ -251,7 +251,8 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - return sd_models.find_checkpoint_config(x) if x else None + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None if config_source == 0: cfg = config(a) or config(b) or config(c) @@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): +chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): shared.state.begin() shared.state.job = 'model-merge' - shared.state.job_count = 1 def fail(message): shared.state.textinfo = message @@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) + def filename_weighed_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_differnece(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighed_sum, None, weighted_sum), + "Add difference": (filename_add_differnece, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + if not primary_model_name: return fail("Failed: Merging requires a primary model.") primary_model_info = sd_models.checkpoints_list[primary_model_name] - if not secondary_model_name: + if theta_func2 and not secondary_model_name: return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - theta_funcs = { - "Weighted sum": (None, weighted_sum), - "Add difference": (get_difference, add_difference), - } - theta_func1, theta_func2 = theta_funcs[interp_method] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None if theta_func1 and not tertiary_model_name: return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - shared.state.textinfo = f"Loading {secondary_model_info.filename}..." - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None if theta_func1: - shared.state.job_count += 1 - + shared.state.textinfo = f"Loading C" print(f"Loading {tertiary_model_info.filename}...") theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + shared.state.textinfo = 'Merging B and C' shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): + if key in chckpoint_dict_skip_on_merge: + continue + if 'model' in key: if key in theta_2: t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) @@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') print("Merging...") - - chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - + shared.state.textinfo = 'Merging A and B' shared.state.sampling_steps = len(theta_0.keys()) for key in tqdm.tqdm(theta_0.keys()): - if 'model' in key and key in theta_1: + if theta_1 and 'model' in key and key in theta_1: if key in chckpoint_dict_skip_on_merge: continue @@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ a = theta_0[key] b = theta_1[key] - shared.state.textinfo = f'Merging layer {key}' # this enables merging an inpainting model (A) with another one (B); # where normal model would have 4 channels, for latenst space, inpainting model would # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 @@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.sampling_step += 1 - # I believe this part should be discarded, but I'll leave it for now until I am sure - for key in theta_1.keys(): - if 'model' in key and key not in theta_0: - - if key in chckpoint_dict_skip_on_merge: - continue - - theta_0[key] = theta_1[key] - if save_as_half: - theta_0[key] = theta_0[key].half() del theta_1 + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key] + + del vae_dict + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = \ - primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \ - secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \ - interp_method.replace(" ", "_") + \ - '-merged.' + \ - ("inpainting." if result_is_inpainting_model else "") + \ - checkpoint_format - - filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format) + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) shared.state.nextjob() - shared.state.textinfo = f"Saving to {output_modelname}..." + shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") _, extension = os.path.splitext(output_modelname) @@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - print("Checkpoint saved.") - shared.state.textinfo = "Checkpoint saved to " + output_modelname + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" shared.state.end() return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/sd_vae.py b/modules/sd_vae.py index da1bf15c..4ce238b8 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file): return None, None +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + def load_vae(model, vae_file=None, vae_source="from unknown source"): global vae_dict, loaded_vae_file # save_settings = False @@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) - vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) _load_vae_dict(model, vae_dict_1) if cache_enabled: diff --git a/modules/shared.py b/modules/shared.py index 77e5e91c..29b28bff 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path demo = None +sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") diff --git a/modules/ui.py b/modules/ui.py index aeee7853..4e381a49 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -1185,7 +1185,7 @@ def create_ui(): with gr.Column(variant='compact'): gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") - with FormRow(): + with FormRow(elem_id="modelmerger_models"): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1197,13 +1197,20 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + with FormRow(): + with gr.Column(): + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + with gr.Column(): + with FormRow(): + bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") + create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') @@ -1757,6 +1764,7 @@ def create_ui(): return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results + modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result]) modelmerger_merge.click( fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), _js='modelmerger', @@ -1771,6 +1779,7 @@ def create_ui(): custom_name, checkpoint_format, config_source, + bake_in_vae, ], outputs=[ primary_model_name, diff --git a/style.css b/style.css index 32ba4753..c10e32a1 100644 --- a/style.css +++ b/style.css @@ -641,6 +641,16 @@ canvas[key="mask"] { margin: 0.6em 0em 0.55em 0; } +#modelmerger_results_container{ + margin-top: 1em; + overflow: visible; +} + +#modelmerger_models{ + gap: 0; +} + + #quicksettings .gr-button-tool{ margin: 0; } @@ -737,11 +747,6 @@ footer { line-height: 2.4em; } -#modelmerger_results_container{ - margin-top: 1em; - overflow: visible; -} - /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running