fix: handles when state_dict does not exist

This commit is contained in:
leko 2022-10-07 23:09:21 +08:00 committed by AUTOMATIC1111
parent 87db6f01cc
commit 616b7218f7

View File

@ -122,7 +122,11 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
pl_sd = torch.load(checkpoint_file, map_location="cpu") pl_sd = torch.load(checkpoint_file, map_location="cpu")
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"] sd = pl_sd["state_dict"]
else:
sd = pl_sd
model.load_state_dict(sd, strict=False) model.load_state_dict(sd, strict=False)