add option SWIN_torch_compile to accelerate SwinIR upscale using torch.compile()
This commit is contained in:
parent
4da92281f6
commit
44d66daaad
@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import platform
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -18,6 +19,8 @@ device_swinir = devices.get_device_for('swinir')
|
||||
|
||||
class UpscalerSwinIR(Upscaler):
|
||||
def __init__(self, dirname):
|
||||
self._cached_model = None # keep the model when SWIN_torch_compile is on to prevent re-compile every runs
|
||||
self._cached_model_config = None # to clear '_cached_model' when changing model (v1/v2) or settings
|
||||
self.name = "SwinIR"
|
||||
self.model_url = SWINIR_MODEL_URL
|
||||
self.model_name = "SwinIR 4x"
|
||||
@ -35,12 +38,24 @@ class UpscalerSwinIR(Upscaler):
|
||||
self.scalers = scalers
|
||||
|
||||
def do_upscale(self, img, model_file):
|
||||
try:
|
||||
model = self.load_model(model_file)
|
||||
except Exception as e:
|
||||
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model = model.to(device_swinir, dtype=devices.dtype)
|
||||
use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \
|
||||
and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows"
|
||||
current_config = (model_file, opts.SWIN_tile)
|
||||
|
||||
if use_compile and self._cached_model_config == current_config:
|
||||
model = self._cached_model
|
||||
else:
|
||||
self._cached_model = None
|
||||
try:
|
||||
model = self.load_model(model_file)
|
||||
except Exception as e:
|
||||
print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
|
||||
return img
|
||||
model = model.to(device_swinir, dtype=devices.dtype)
|
||||
if use_compile:
|
||||
model = torch.compile(model)
|
||||
self._cached_model = model
|
||||
self._cached_model_config = current_config
|
||||
img = upscale(img, model)
|
||||
devices.torch_gc()
|
||||
return img
|
||||
@ -170,6 +185,8 @@ def on_ui_settings():
|
||||
|
||||
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
|
||||
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
|
||||
if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows
|
||||
shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run"))
|
||||
|
||||
|
||||
script_callbacks.on_ui_settings(on_ui_settings)
|
||||
|
Loading…
Reference in New Issue
Block a user