From 7acfaca05a13352a7d86d281db6ad64dfd9350e0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 29 Sep 2022 00:59:44 +0300 Subject: [PATCH] update lists of models after merging them in checkpoints tab support saving as half --- modules/extras.py | 27 +++++++++++++++++---------- modules/sd_models.py | 15 ++++++++++----- modules/shared.py | 2 +- modules/ui.py | 42 ++++++++++++++++++++++++------------------ 4 files changed, 52 insertions(+), 34 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index dcc0148c..9a825530 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -13,6 +13,7 @@ from modules.ui import plaintext_to_html import modules.codeformer_model import piexif import piexif.helper +import gradio as gr cached_images = {} @@ -140,7 +141,7 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount): +def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half): # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -156,14 +157,14 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) return theta0 + ((theta1 - theta0) * alpha) - primary_model_filename = sd_models.checkpoints_list[primary_model_name].filename - secondary_model_filename = sd_models.checkpoints_list[secondary_model_name].filename + primary_model_info = sd_models.checkpoints_list[primary_model_name] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - print(f"Loading {primary_model_filename}...") - primary_model = torch.load(primary_model_filename, map_location='cpu') + print(f"Loading {primary_model_info.filename}...") + primary_model = torch.load(primary_model_info.filename, map_location='cpu') - print(f"Loading {secondary_model_filename}...") - secondary_model = torch.load(secondary_model_filename, map_location='cpu') + print(f"Loading {secondary_model_info.filename}...") + secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') theta_0 = primary_model['state_dict'] theta_1 = secondary_model['state_dict'] @@ -178,17 +179,23 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int print(f"Merging...") for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + if save_as_half: + theta_0[key] = theta_0[key].half() for key in theta_1.keys(): if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] + if save_as_half: + theta_0[key] = theta_0[key].half() - filename = primary_model_name + '_' + str(round(interp_amount,2)) + '-' + secondary_model_name + '_' + str(round((float(1.0) - interp_amount),2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename) print(f"Saving to {output_modelname}...") torch.save(primary_model, output_modelname) + sd_models.list_models() + print(f"Checkpoint saved.") - return "Checkpoint saved to " + output_modelname + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] diff --git a/modules/sd_models.py b/modules/sd_models.py index 9decc911..dd47dffb 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -23,6 +23,11 @@ except Exception: pass +def checkpoint_tiles(): + print(sorted([x.title for x in checkpoints_list.values()])) + return sorted([x.title for x in checkpoints_list.values()]) + + def list_models(): checkpoints_list.clear() @@ -39,13 +44,14 @@ def list_models(): if name.startswith("\\") or name.startswith("/"): name = name[1:] - return f'{name} [{h}]' + shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + + return f'{name} [{h}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) - title = modeltitle(cmd_ckpt, h) - model_name = title.rsplit(".",1)[0] # remove extension if present + title, model_name = modeltitle(cmd_ckpt, h) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr) @@ -53,8 +59,7 @@ def list_models(): if os.path.exists(model_dir): for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True): h = model_hash(filename) - title = modeltitle(filename, h) - model_name = title.rsplit(".",1)[0] # remove extension if present + title, model_name = modeltitle(filename, h) checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) diff --git a/modules/shared.py b/modules/shared.py index 39cf89bc..ec1e569b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -190,7 +190,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/modules/ui.py b/modules/ui.py index d51f7a08..4958036a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -872,29 +872,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): gr.HTML(value="

A merger of the two checkpoints will be generated in your /models directory.

") with gr.Row(): - ckpt_name_list = sorted([x.title for x in modules.sd_models.checkpoints_list.values()]) - primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name") - secondary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") - submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + save_as_half = gr.Checkbox(value=False, label="Safe as float16") + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - submit.click( - fn=run_modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - interp_method, - interp_amount - ], - outputs=[ - submit_result, - ] - ) - def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -918,6 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): return comp(label=info.label, value=fun, **(args or {})) components = [] + component_dict = {} def run_settings(*args): changed = 0 @@ -973,7 +961,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - components.append(create_setting_component(k)) + component = create_setting_component(k) + component_dict[k] = component + components.append(component) items_displayed += 1 request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") @@ -1024,6 +1014,22 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): outputs=[result, text_settings], ) + modelmerger_merge.click( + fn=run_modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + interp_method, + interp_amount, + save_as_half, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Seed', 'Size-1', 'Size-2'] txt2img_fields = [field for field,name in txt2img_paste_fields if name in paste_field_names] img2img_fields = [field for field,name in img2img_paste_fields if name in paste_field_names]