Upscaler.load_model: don't return None, just use exceptions
This commit is contained in:
parent
e3a973a68d
commit
bf67a5dcf4
@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler):
|
||||
|
||||
yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
|
||||
|
||||
try:
|
||||
return LDSR(model, yaml)
|
||||
except Exception:
|
||||
errors.report("Error importing LDSR", exc_info=True)
|
||||
return None
|
||||
|
||||
def do_upscale(self, img, path):
|
||||
try:
|
||||
ldsr = self.load_model(path)
|
||||
if ldsr is None:
|
||||
print("NO LDSR!")
|
||||
except Exception:
|
||||
errors.report(f"Failed loading LDSR model {path}", exc_info=True)
|
||||
return img
|
||||
ddim_steps = shared.opts.ldsr_steps
|
||||
return ldsr.super_resolution(img, ddim_steps, self.scale)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
import PIL.Image
|
||||
@ -8,7 +7,7 @@ from tqdm import tqdm
|
||||
|
||||
import modules.upscaler
|
||||
from modules import devices, modelloader, script_callbacks, errors
|
||||
from scunet_model_arch import SCUNet as net
|
||||
from scunet_model_arch import SCUNet
|
||||
|
||||
from modules.modelloader import load_file_from_url
|
||||
from modules.shared import opts
|
||||
@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
try:
|
||||
model = self.load_model(selected_file)
|
||||
if model is None:
|
||||
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
|
||||
return img
|
||||
|
||||
device = devices.get_device_for('scunet')
|
||||
@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
|
||||
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||
model.load_state_dict(torch.load(filename), strict=True)
|
||||
model.eval()
|
||||
for _, v in model.named_parameters():
|
||||
|
@ -1,4 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -7,8 +7,8 @@ from tqdm import tqdm
|
||||
|
||||
from modules import modelloader, devices, script_callbacks, shared
|
||||
from modules.shared import opts, state
|
||||
from swinir_model_arch import SwinIR as net
|
||||
from swinir_model_arch_v2 import Swin2SR as net2
|
||||
from swinir_model_arch import SwinIR
|
||||
from swinir_model_arch_v2 import Swin2SR
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
||||
@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler):
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img, model_file):
|
||||
try:
|
||||
model = self.load_model(model_file)
|
||||
if model is None:
|
||||
except Exception as e:
|
||||
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model = model.to(device_swinir, dtype=devices.dtype)
|
||||
img = upscale(img, model)
|
||||
@ -56,10 +58,8 @@ class UpscalerSwinIR(Upscaler):
|
||||
)
|
||||
else:
|
||||
filename = path
|
||||
if filename is None or not os.path.exists(filename):
|
||||
return None
|
||||
if filename.endswith(".v2.pth"):
|
||||
model = net2(
|
||||
model = Swin2SR(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
@ -74,7 +74,7 @@ class UpscalerSwinIR(Upscaler):
|
||||
)
|
||||
params = None
|
||||
else:
|
||||
model = net(
|
||||
model = SwinIR(
|
||||
upscale=scale,
|
||||
in_chans=3,
|
||||
img_size=64,
|
||||
|
@ -1,4 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -6,9 +6,8 @@ from PIL import Image
|
||||
|
||||
import modules.esrgan_model_arch as arch
|
||||
from modules import modelloader, images, devices
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.shared import opts
|
||||
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
|
||||
|
||||
def mod2normal(state_dict):
|
||||
@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler):
|
||||
self.scalers.append(scaler_data)
|
||||
|
||||
def do_upscale(self, img, selected_model):
|
||||
try:
|
||||
model = self.load_model(selected_model)
|
||||
if model is None:
|
||||
except Exception as e:
|
||||
print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model.to(devices.device_esrgan)
|
||||
img = esrgan_upscale(model, img)
|
||||
@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler):
|
||||
)
|
||||
else:
|
||||
filename = path
|
||||
if not os.path.exists(filename) or filename is None:
|
||||
print(f"Unable to load {self.model_path} from {filename}")
|
||||
return None
|
||||
|
||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||
|
||||
|
@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts
|
||||
from modules import modelloader, errors
|
||||
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
def __init__(self, path):
|
||||
self.name = "RealESRGAN"
|
||||
@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
if not self.enable:
|
||||
return img
|
||||
|
||||
try:
|
||||
info = self.load_model(path)
|
||||
if not os.path.exists(info.local_data_path):
|
||||
print(f"Unable to load RealESRGAN model: {info.name}")
|
||||
except Exception:
|
||||
errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
|
||||
return img
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
|
||||
return image
|
||||
|
||||
def load_model(self, path):
|
||||
try:
|
||||
info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
|
||||
|
||||
if info is None:
|
||||
print(f"Unable to find model info: {path}")
|
||||
return None
|
||||
|
||||
if info.local_data_path.startswith("http"):
|
||||
info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path)
|
||||
|
||||
return info
|
||||
except Exception:
|
||||
errors.report("Error making Real-ESRGAN models list", exc_info=True)
|
||||
return None
|
||||
for scaler in self.scalers:
|
||||
if scaler.data_path == path:
|
||||
if scaler.local_data_path.startswith("http"):
|
||||
scaler.local_data_path = modelloader.load_file_from_url(
|
||||
scaler.data_path,
|
||||
model_dir=self.model_download_path,
|
||||
)
|
||||
if not os.path.exists(scaler.local_data_path):
|
||||
raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
|
||||
return scaler
|
||||
raise ValueError(f"Unable to find model info: {path}")
|
||||
|
||||
def load_models(self, _):
|
||||
return get_realesrgan_models(self)
|
||||
|
Loading…
Reference in New Issue
Block a user