add option SWIN_torch_compile to accelerate SwinIR upscale using torch.compile()

This commit is contained in:
SiYu Wu 2023-07-09 03:05:38 +08:00
parent 4da92281f6
commit 44d66daaad

View File

@ -1,4 +1,5 @@
import sys import sys
import platform
import numpy as np import numpy as np
import torch import torch
@ -18,6 +19,8 @@ device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
self.name = "SwinIR" self.name = "SwinIR"
self.model_url = SWINIR_MODEL_URL self.model_url = SWINIR_MODEL_URL
self.model_name = "SwinIR 4x" self.model_name = "SwinIR 4x"
@ -35,12 +38,24 @@ class UpscalerSwinIR(Upscaler):
self.scalers = scalers self.scalers = scalers
def do_upscale(self, img, model_file): def do_upscale(self, img, model_file):
try: use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
model = self.load_model(model_file) and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
except Exception as e: current_config = (model_file, opts.SWIN_tile)
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img if use_compile and self._cached_model_config == current_config:
model = model.to(device_swinir, dtype=devices.dtype) model = self._cached_model
else:
self._cached_model = None
try:
model = self.load_model(model_file)
except Exception as e:
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
return img
model = model.to(device_swinir, dtype=devices.dtype)
if use_compile:
model = torch.compile(model)
self._cached_model = model
self._cached_model_config = current_config
img = upscale(img, model) img = upscale(img, model)
devices.torch_gc() devices.torch_gc()
return img return img
@ -170,6 +185,8 @@ def on_ui_settings():
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
script_callbacks.on_ui_settings(on_ui_settings) script_callbacks.on_ui_settings(on_ui_settings)