Merge pull request #10823 from akx/model-loady
Upscaler model loading cleanup
This commit is contained in commit 3cd4fd51ef.
--- a/extensions-builtin/LDSR/scripts/ldsr_model.py
+++ b/extensions-builtin/LDSR/scripts/ldsr_model.py
@@ -1,7 +1,6 @@
 import os
 
-from basicsr.utils.download_util import load_file_from_url
-
+from modules.modelloader import load_file_from_url
 from modules.upscaler import Upscaler, UpscalerData
 from ldsr_model_arch import LDSR
 from modules import shared, script_callbacks, errors
@@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
         if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
             model = local_safetensors_path
         else:
-            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
+            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")
 
-        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
+        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")
 
-        try:
-            return LDSR(model, yaml)
-        except Exception:
-            errors.report("Error importing LDSR", exc_info=True)
-        return None
+        return LDSR(model, yaml)
 
     def do_upscale(self, img, path):
-        ldsr = self.load_model(path)
-        if ldsr is None:
-            print("NO LDSR!")
+        try:
+            ldsr = self.load_model(path)
+        except Exception:
+            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
             return img
         ddim_steps = shared.opts.ldsr_steps
         return ldsr.super_resolution(img, ddim_steps, self.scale)
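
The net change in this file is a contract change: load_model no longer returns None on failure but lets exceptions propagate, and do_upscale is now the place that catches them and falls back to returning the input image unchanged. Below is a minimal, self-contained sketch of that raise-and-catch pattern; FakeUpscaler and its messages are hypothetical stand-ins, not code from the PR.

class FakeUpscaler:
    def load_model(self, path):
        # No more "return None" sentinel: any failure (bad path, download error,
        # broken checkpoint) is simply allowed to raise.
        if not path.endswith(".ckpt"):
            raise ValueError(f"not a checkpoint: {path}")
        return object()  # stand-in for the loaded model

    def do_upscale(self, img, path):
        try:
            model = self.load_model(path)
        except Exception as exc:
            # The caller decides how to degrade: log and return the image unchanged.
            print(f"Failed loading model {path}: {exc}")
            return img
        return img  # real code would run the model here


print(FakeUpscaler().do_upscale("IMG", "missing.bin"))  # logs the error, returns "IMG"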
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -1,4 +1,3 @@
-import os.path
 import sys
 
 import PIL.Image
@@ -6,12 +5,11 @@ import numpy as np
 import torch
 from tqdm import tqdm
 
-from basicsr.utils.download_util import load_file_from_url
-
 import modules.upscaler
 from modules import devices, modelloader, script_callbacks, errors
-from scunet_model_arch import SCUNet as net
+from scunet_model_arch import SCUNet
 
+from modules.modelloader import load_file_from_url
 from modules.shared import opts
 
 
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
         scalers = []
         add_model2 = True
         for file in model_paths:
-            if "http" in file:
+            if file.startswith("http"):
                 name = self.model_name
             else:
                 name = modelloader.friendly_name(file)
@@ -89,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
         torch.cuda.empty_cache()
 
-        model = self.load_model(selected_file)
-        if model is None:
-            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
+        try:
+            model = self.load_model(selected_file)
+        except Exception as e:
+            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
             return img
 
         device = devices.get_device_for('scunet')
@@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
 
     def load_model(self, path: str):
        device = devices.get_device_for('scunet')
-        if "http" in path:
-            filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
+        if path.startswith("http"):
+            # TODO: this doesn't use `path` at all?
+            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
         else:
             filename = path
-        if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
-            print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
-            return None
-
-        model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
+        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
         model.load_state_dict(torch.load(filename), strict=True)
         model.eval()
         for _, v in model.named_parameters():
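
The recurring `"http" in file` to `file.startswith("http")` change across these files tightens the URL check: substring matching would also classify a purely local path that happens to contain "http" as a downloadable URL. A small illustration with made-up paths:

candidates = [
    "https://example.com/models/scunet.pth",  # actual URL
    "/home/user/models/http_mirror/x.pth",    # local path that merely contains "http"
]

for p in candidates:
    old_rule = "http" in p           # True for both entries
    new_rule = p.startswith("http")  # True only for the real URL
    print(f"{p!r}: old={old_rule} new={new_rule}")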
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -1,17 +1,17 @@
-import os
+import sys
 
 import numpy as np
 import torch
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
 from modules.shared import opts, state
-from swinir_model_arch import SwinIR as net
-from swinir_model_arch_v2 import Swin2SR as net2
+from swinir_model_arch import SwinIR
+from swinir_model_arch_v2 import Swin2SR
 from modules.upscaler import Upscaler, UpscalerData
 
-
+SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
+
 device_swinir = devices.get_device_for('swinir')
 
@@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
 class UpscalerSwinIR(Upscaler):
     def __init__(self, dirname):
         self.name = "SwinIR"
-        self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
-                         "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
-                         "-L_x4_GAN.pth "
+        self.model_url = SWINIR_MODEL_URL
         self.model_name = "SwinIR 4x"
         self.user_path = dirname
         super().__init__()
         scalers = []
         model_files = self.find_models(ext_filter=[".pt", ".pth"])
         for model in model_files:
-            if "http" in model:
+            if model.startswith("http"):
                 name = self.model_name
             else:
                 name = modelloader.friendly_name(model)
@@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler):
         self.scalers = scalers
 
     def do_upscale(self, img, model_file):
-        model = self.load_model(model_file)
-        if model is None:
+        try:
+            model = self.load_model(model_file)
+        except Exception as e:
+            print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
             return img
         model = model.to(device_swinir, dtype=devices.dtype)
         img = upscale(img, model)
@@ -49,15 +49,16 @@ class UpscalerSwinIR(Upscaler):
         return img
 
     def load_model(self, path, scale=4):
-        if "http" in path:
-            dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
-            filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
+        if path.startswith("http"):
+            filename = modelloader.load_file_from_url(
+                url=path,
+                model_dir=self.model_download_path,
+                file_name=f"{self.model_name.replace(' ', '_')}.pth",
+            )
         else:
             filename = path
-        if filename is None or not os.path.exists(filename):
-            return None
         if filename.endswith(".v2.pth"):
-            model = net2(
+            model = Swin2SR(
                 upscale=scale,
                 in_chans=3,
                 img_size=64,
@@ -72,7 +73,7 @@ class UpscalerSwinIR(Upscaler):
             )
             params = None
         else:
-            model = net(
+            model = SwinIR(
                 upscale=scale,
                 in_chans=3,
                 img_size=64,
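
Hoisting the URL into the module-level SWINIR_MODEL_URL constant also fixes a subtle defect: the old three-part string literal ended in "-L_x4_GAN.pth " with a trailing space. A quick check of what the two spellings evaluate to, using the values copied from the hunks above:

old_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
          "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
          "-L_x4_GAN.pth "

SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"

assert old_url.strip() == SWINIR_MODEL_URL  # same target once stripped
assert old_url != SWINIR_MODEL_URL          # but the old literal carries a trailing space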
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,15 +1,13 @@
-import os
+import sys
 
 import numpy as np
 import torch
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 
 import modules.esrgan_model_arch as arch
 from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
 from modules.shared import opts
+from modules.upscaler import Upscaler, UpscalerData
 
 
 def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
         scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
         scalers.append(scaler_data)
         for file in model_paths:
-            if "http" in file:
+            if file.startswith("http"):
                 name = self.model_name
             else:
                 name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
         self.scalers.append(scaler_data)
 
     def do_upscale(self, img, selected_model):
-        model = self.load_model(selected_model)
-        if model is None:
+        try:
+            model = self.load_model(selected_model)
+        except Exception as e:
+            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
             return img
         model.to(devices.device_esrgan)
         img = esrgan_upscale(model, img)
         return img
 
     def load_model(self, path: str):
-        if "http" in path:
-            filename = load_file_from_url(
+        if path.startswith("http"):
+            # TODO: this doesn't use `path` at all?
+            filename = modelloader.load_file_from_url(
                 url=self.model_url,
                 model_dir=self.model_download_path,
                 file_name=f"{self.model_name}.pth",
-                progress=True,
             )
         else:
             filename = path
-        if not os.path.exists(filename) or filename is None:
-            print(f"Unable to load {self.model_path} from {filename}")
-            return None
 
         state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
 
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -25,7 +25,7 @@ def gfpgann():
         return None
 
     models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
-    if len(models) == 1 and "http" in models[0]:
+    if len(models) == 1 and models[0].startswith("http"):
         model_file = models[0]
     elif len(models) != 0:
         latest_file = max(models, key=os.path.getctime)
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import shutil
 import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
 from modules.paths import script_path, models_path
 
 
+def load_file_from_url(
+    url: str,
+    *,
+    model_dir: str,
+    progress: bool = True,
+    file_name: str | None = None,
+) -> str:
+    """Download a file from `url` into `model_dir`, using the file present if possible.
+
+    Returns the path to the downloaded file.
+    """
+    os.makedirs(model_dir, exist_ok=True)
+    if not file_name:
+        parts = urlparse(url)
+        file_name = os.path.basename(parts.path)
+    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        from torch.hub import download_url_to_file
+        download_url_to_file(url, cached_file, progress=progress)
+    return cached_file
+
+
 def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
     """
     A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
 
     if model_url is not None and len(output) == 0:
         if download_name is not None:
-            from basicsr.utils.download_util import load_file_from_url
-            dl = load_file_from_url(model_url, places[0], True, download_name)
-            output.append(dl)
+            output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
         else:
             output.append(model_url)
 
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
 
 
 def friendly_name(file: str):
-    if "http" in file:
+    if file.startswith("http"):
         file = urlparse(file).path
 
     file = os.path.basename(file)
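
This new modules.modelloader.load_file_from_url is the single download helper the scripts above now share: model_dir is keyword-only, progress defaults to True, and when file_name is omitted it is derived from the URL's path. A hedged usage sketch; the URL and directory below are placeholders, not values from the PR:

from modules.modelloader import load_file_from_url

# Explicit file name, as the LDSR / ScuNET / SwinIR / ESRGAN callers do:
ckpt = load_file_from_url(
    "https://example.com/weights/model.ckpt",  # placeholder URL
    model_dir="models/LDSR",                   # placeholder directory
    file_name="model.ckpt",
)

# file_name omitted: the basename of the URL path is used, here "project.yaml".
yaml_path = load_file_from_url(
    "https://example.com/weights/project.yaml",
    model_dir="models/LDSR",
)

print(ckpt, yaml_path)  # absolute paths; an existing cached copy is reused on later calls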
--- a/modules/realesrgan_model.py
+++ b/modules/realesrgan_model.py
@@ -2,7 +2,6 @@ import os
 
 import numpy as np
 from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
 from realesrgan import RealESRGANer
 
 from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
         if not self.enable:
             return img
 
-        info = self.load_model(path)
-        if not os.path.exists(info.local_data_path):
-            print(f"Unable to load RealESRGAN model: {info.name}")
+        try:
+            info = self.load_model(path)
+        except Exception:
+            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
             return img
 
         upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
         return image
 
     def load_model(self, path):
-        try:
-            info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
-            if info is None:
-                print(f"Unable to find model info: {path}")
-                return None
-
-            if info.local_data_path.startswith("http"):
-                info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
-            return info
-        except Exception:
-            errors.report("Error making Real-ESRGAN models list", exc_info=True)
-            return None
+        for scaler in self.scalers:
+            if scaler.data_path == path:
+                if scaler.local_data_path.startswith("http"):
+                    scaler.local_data_path = modelloader.load_file_from_url(
+                        scaler.data_path,
+                        model_dir=self.model_download_path,
+                    )
+                if not os.path.exists(scaler.local_data_path):
+                    raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+                return scaler
+        raise ValueError(f"Unable to find model info: {path}")
 
     def load_models(self, _):
         return get_realesrgan_models(self)
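
The rewritten RealESRGAN load_model drops the next(iter(...)) lookup and the blanket try/except in favour of a plain loop that raises specific exceptions (FileNotFoundError for missing local data, ValueError for an unknown path), which do_upscale above turns into a log message plus fallback. A minimal standalone sketch of that lookup shape; Scaler is a hypothetical stand-in for the real UpscalerData entries:

from dataclasses import dataclass


@dataclass
class Scaler:  # stand-in for UpscalerData
    data_path: str
    local_data_path: str


def find_scaler(scalers, path):
    for scaler in scalers:
        if scaler.data_path == path:
            return scaler
    raise ValueError(f"Unable to find model info: {path}")


scalers = [Scaler("https://example.com/r-esrgan-4x.pth", "models/RealESRGAN/r-esrgan-4x.pth")]
print(find_scaler(scalers, "https://example.com/r-esrgan-4x.pth"))
# find_scaler(scalers, "unknown") would raise ValueError, which the caller logs and survives.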