From 791808c890fc2fc3417f827f8744765970b23f13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 29 Sep 2022 00:21:54 +0300 Subject: [PATCH] correctly list and display model names for #1261 --- modules/extras.py | 23 ++++++++--------------- modules/ui.py | 4 ++-- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 15de033a..dcc0148c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -6,7 +6,7 @@ from PIL import Image import torch import tqdm -from modules import processing, shared, images, devices +from modules import processing, shared, images, devices, sd_models from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html @@ -156,17 +156,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) return theta0 + ((theta1 - theta0) * alpha) - if os.path.exists(primary_model_name): - primary_model_filename = primary_model_name - primary_model_name = os.path.splitext(os.path.basename(primary_model_name))[0] - else: - primary_model_filename = 'models/' + primary_model_name + '.ckpt' - - if os.path.exists(secondary_model_name): - secondary_model_filename = secondary_model_name - secondary_model_name = os.path.splitext(os.path.basename(secondary_model_name))[0] - else: - secondary_model_filename = 'models/' + secondary_model_name + '.ckpt' + primary_model_filename = sd_models.checkpoints_list[primary_model_name].filename + secondary_model_filename = sd_models.checkpoints_list[secondary_model_name].filename print(f"Loading {primary_model_filename}...") primary_model = torch.load(primary_model_filename, map_location='cpu') @@ -180,7 +171,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int theta_funcs = { "Weighted Sum": weighted_sum, "Sigmoid": sigmoid, - "Inverse Sigmoid": inv_sigmoid + "Inverse Sigmoid": inv_sigmoid, } theta_func = theta_funcs[interp_method] @@ -193,9 +184,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] - output_modelname = 'models/' + primary_model_name + '_' + str(round(interp_amount,2)) + '-' + secondary_model_name + '_' + str(round((float(1.0) - interp_amount),2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + filename = primary_model_name + '_' + str(round(interp_amount,2)) + '-' + secondary_model_name + '_' + str(round((float(1.0) - interp_amount),2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename) + print(f"Saving to {output_modelname}...") torch.save(primary_model, output_modelname) print(f"Checkpoint saved.") - return "Checkpoint saved to " + output_modelname \ No newline at end of file + return "Checkpoint saved to " + output_modelname diff --git a/modules/ui.py b/modules/ui.py index bf736b27..d51f7a08 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -872,8 +872,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): gr.HTML(value="

A merger of the two checkpoints will be generated in your /models directory.

") with gr.Row(): - ckpt_name_list = sorted([x.model_name for x in modules.sd_models.checkpoints_list.values()]) - primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name") + ckpt_name_list = sorted([x.title for x in modules.sd_models.checkpoints_list.values()]) + primary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(ckpt_name_list, elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")