From bf67a5dcf44c3dbd88d1913478d4e02477915f33 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 10:38:51 +0300 Subject: [PATCH] Upscaler.load_model: don't return None, just use exceptions --- extensions-builtin/LDSR/scripts/ldsr_model.py | 13 +++--- .../ScuNET/scripts/scunet_model.py | 16 +++----- .../SwinIR/scripts/swinir_model.py | 40 +++++++++---------- modules/esrgan_model.py | 14 +++---- modules/realesrgan_model.py | 33 +++++++-------- 5 files changed, 52 insertions(+), 64 deletions(-) diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index bf9b6de2..bd78dece 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler): yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") - try: - return LDSR(model, yaml) - except Exception: - errors.report("Error importing LDSR", exc_info=True) - return None + return LDSR(model, yaml) def do_upscale(self, img, path): - ldsr = self.load_model(path) - if ldsr is None: - print("NO LDSR!") + try: + ldsr = self.load_model(path) + except Exception: + errors.report(f"Failed loading LDSR model {path}", exc_info=True) return img ddim_steps = shared.opts.ldsr_steps return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index da74a829..ffef26b2 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,4 +1,3 @@ -import os.path import sys import PIL.Image @@ -8,7 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet as net +from scunet_model_arch import SCUNet from modules.modelloader import load_file_from_url from modules.shared import opts @@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): torch.cuda.empty_cache() - model = self.load_model(selected_file) - if model is None: - print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) + try: + model = self.load_model(selected_file) + except Exception as e: + print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img device = devices.get_device_for('scunet') @@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: - print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) - return None - - model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model.load_state_dict(torch.load(filename), strict=True) model.eval() for _, v in model.named_parameters(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 4551761d..3ce622d9 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -7,8 +7,8 @@ from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR as net -from swinir_model_arch_v2 import Swin2SR as net2 +from swinir_model_arch import SwinIR +from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData @@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img, model_file): - model = self.load_model(model_file) - if model is None: + try: + model = self.load_model(model_file) + except Exception as e: + print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img model = model.to(device_swinir, dtype=devices.dtype) img = upscale(img, model) @@ -56,25 +58,23 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename is None or not os.path.exists(filename): - return None if filename.endswith(".v2.pth"): - model = net2( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", + model = Swin2SR( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="1conv", ) params = None else: - model = net( + model = SwinIR( upscale=scale, in_chans=3, img_size=64, diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a20e8d91..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -6,9 +6,8 @@ from PIL import Image import modules.esrgan_model_arch as arch from modules import modelloader, images, devices -from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts - +from modules.upscaler import Upscaler, UpscalerData def mod2normal(state_dict): @@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data) def do_upscale(self, img, selected_model): - model = self.load_model(selected_model) - if model is None: + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) return img model.to(devices.device_esrgan) img = esrgan_upscale(model, img) @@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler): ) else: filename = path - if not os.path.exists(filename) or filename is None: - print(f"Unable to load {self.model_path} from {filename}") - return None state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 0d9c2e48..0700b853 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors - class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): if not self.enable: return img - info = self.load_model(path) - if not os.path.exists(info.local_data_path): - print(f"Unable to load RealESRGAN model: {info.name}") + try: + info = self.load_model(path) + except Exception: + errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img upsampler = RealESRGANer( @@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): return image def load_model(self, path): - try: - info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) - - if info is None: - print(f"Unable to find model info: {path}") - return None - - if info.local_data_path.startswith("http"): - info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) - - return info - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) - return None + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") def load_models(self, _): return get_realesrgan_models(self)