From ad02b249f5bf8e494c35a313f44515b7b1e6739d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 8 Sep 2022 15:49:47 +0300 Subject: [PATCH] add a helpful message when user puts RealESRGAN model into ESRGAN directory. --- modules/esrgan_model.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2ed1d273..e86ad775 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -14,17 +14,20 @@ import modules.images def load_model(filename): # this code is adapted from https://github.com/xinntao/ESRGAN - if torch.has_mps: - map_l = 'cpu' - else: - map_l = None - pretrained_net = torch.load(filename, map_location=map_l) + pretrained_net = torch.load(filename, map_location='cpu' if torch.has_mps else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) if 'conv_first.weight' in pretrained_net: crt_model.load_state_dict(pretrained_net) return crt_model + if 'model.0.weight' not in pretrained_net: + is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] + if is_realesrgan: + raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") + else: + raise Exception("The file is not a ESRGAN model.") + crt_net = crt_model.state_dict() load_net_clean = {} for k, v in pretrained_net.items():