make main model loading and model merger use the same code

This commit is contained in:
AUTOMATIC 2022-10-09 10:23:31 +03:00
parent 050a6a798c
commit c77c89cc83
2 changed files with 12 additions and 8 deletions

View File

@ -170,8 +170,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_0 = primary_model['state_dict']
theta_1 = secondary_model['state_dict']
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
theta_funcs = {
"Weighted Sum": weighted_sum,

View File

@ -122,6 +122,13 @@ def select_checkpoint():
return checkpoint_info
def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd:
return pl_sd["state_dict"]
return pl_sd
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
@ -132,10 +139,7 @@ def load_model_weights(model, checkpoint_info):
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
sd = get_state_dict_from_checkpoint(pl_sd)
model.load_state_dict(sd, strict=False)