add a helpful message when user puts RealESRGAN model into ESRGAN directory.

This commit is contained in:
AUTOMATIC 2022-09-08 15:49:47 +03:00
parent 62ce77e245
commit ad02b249f5

View File

@ -14,17 +14,20 @@ import modules.images
def load_model(filename): def load_model(filename):
# this code is adapted from https://github.com/xinntao/ESRGAN # this code is adapted from https://github.com/xinntao/ESRGAN
if torch.has_mps: pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None)
map_l = 'cpu'
else:
map_l = None
pretrained_net = torch.load(filename, map_location=map_l)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
if 'conv_first.weight' in pretrained_net: if 'conv_first.weight' in pretrained_net:
crt_model.load_state_dict(pretrained_net) crt_model.load_state_dict(pretrained_net)
return crt_model return crt_model
if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
if is_realesrgan:
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
else:
raise Exception("The file is not a ESRGAN model.")
crt_net = crt_model.state_dict() crt_net = crt_model.state_dict()
load_net_clean = {} load_net_clean = {}
for k, v in pretrained_net.items(): for k, v in pretrained_net.items():