update ESRGAN architecture and model to support all ESRGAN models in the DB, BSRGAN and real-ESRGAN models

This commit is contained in:
victorca25 2022-10-09 13:02:12 +02:00 committed by victorca25
parent f4578b343d
commit ad4de819c4
5 changed files with 585 additions and 314 deletions

View File

@ -1,76 +0,0 @@
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet
class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "BSRGAN"
self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
model.to(devices.device_bsrgan)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_bsrgan)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
return None
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
return model

View File

@ -1,102 +0,0 @@
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.sf==4:
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
if self.sf==4:
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out

View File

@ -1,80 +0,0 @@
# this file is taken from https://github.com/xinntao/ESRGAN
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
# mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out

View File

@ -5,68 +5,115 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
def fix_model_layers(crt_model, pretrained_net):
# this code is adapted from https://github.com/xinntao/ESRGAN
if 'conv_first.weight' in pretrained_net:
return pretrained_net
if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
if is_realesrgan:
raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
else:
raise Exception("The file is not a ESRGAN model.")
def mod2normal(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
if 'conv_first.weight' in state_dict:
crt_net = {}
items = []
for k, v in state_dict.items():
items.append(k)
crt_net = crt_model.state_dict()
load_net_clean = {}
for k, v in pretrained_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
pretrained_net = load_net_clean
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
tbd = []
for k, v in crt_net.items():
tbd.append(k)
for k in items.copy():
if 'RDB' in k:
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[ori_k] = state_dict[k]
items.remove(k)
# directly copy
for k, v in crt_net.items():
if k in pretrained_net and pretrained_net[k].size() == v.size():
crt_net[k] = pretrained_net[k]
tbd.remove(k)
crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight']
crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias']
crt_net['model.3.weight'] = state_dict['upconv1.weight']
crt_net['model.3.bias'] = state_dict['upconv1.bias']
crt_net['model.6.weight'] = state_dict['upconv2.weight']
crt_net['model.6.bias'] = state_dict['upconv2.bias']
crt_net['model.8.weight'] = state_dict['HRconv.weight']
crt_net['model.8.bias'] = state_dict['HRconv.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
for k in tbd.copy():
if 'RDB' in k:
ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[k] = pretrained_net[ori_k]
tbd.remove(k)
def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
crt_net = {}
items = []
for k, v in state_dict.items():
items.append(k)
crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
crt_net['model.0.weight'] = state_dict['conv_first.weight']
crt_net['model.0.bias'] = state_dict['conv_first.bias']
for k in items.copy():
if "rdb" in k:
ori_k = k.replace('body.', 'model.1.sub.')
ori_k = ori_k.replace('.rdb', '.RDB')
if '.weight' in k:
ori_k = ori_k.replace('.weight', '.0.weight')
elif '.bias' in k:
ori_k = ori_k.replace('.bias', '.0.bias')
crt_net[ori_k] = state_dict[k]
items.remove(k)
crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight']
crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias']
crt_net['model.3.weight'] = state_dict['conv_up1.weight']
crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias']
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
crt_net['model.8.bias'] = state_dict['conv_hr.bias']
crt_net['model.10.weight'] = state_dict['conv_last.weight']
crt_net['model.10.bias'] = state_dict['conv_last.bias']
state_dict = crt_net
return state_dict
def infer_params(state_dict):
# this code is copied from https://github.com/victorca25/iNNfer
scale2x = 0
scalemin = 6
n_uplayer = 0
plus = False
for block in list(state_dict):
parts = block.split(".")
n_parts = len(parts)
if n_parts == 5 and parts[2] == "sub":
nb = int(parts[3])
elif n_parts == 3:
part_num = int(parts[1])
if (part_num > scalemin
and parts[0] == "model"
and parts[2] == "weight"):
scale2x += 1
if part_num > n_uplayer:
n_uplayer = part_num
out_nc = state_dict[block].shape[0]
if not plus and "conv1x1" in block:
plus = True
nf = state_dict["model.0.weight"].shape[0]
in_nc = state_dict["model.0.weight"].shape[1]
out_nc = out_nc
scale = 2 ** scale2x
return in_nc, out_nc, nf, nb, plus, scale
return crt_net
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
crt_model.load_state_dict(pretrained_net)
crt_model.eval()
if "params_ema" in state_dict:
state_dict = state_dict["params_ema"]
elif "params" in state_dict:
state_dict = state_dict["params"]
num_conv = 16 if "realesr-animevideov3" in filename else 32
model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu')
model.load_state_dict(state_dict)
model.eval()
return model
return crt_model
if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict:
nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23
state_dict = resrgan2normal(state_dict, nb)
elif "conv_first.weight" in state_dict:
state_dict = mod2normal(state_dict)
elif "model.0.weight" not in state_dict:
raise Exception("The file is not a recognized ESRGAN model.")
in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict)
model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus)
model.load_state_dict(state_dict)
model.eval()
return model
def upscale_without_tiling(model, img):
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan)
with torch.no_grad():

View File

@ -0,0 +1,463 @@
# this file is adapted from https://github.com/victorca25/iNNfer
import math
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
####################
# RRDBNet Generator
####################
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None,
act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D',
finalact=None, gaussian_noise=False, plus=False):
super(RRDBNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
self.resrgan_scale = 0
if in_nc % 16 == 0:
self.resrgan_scale = 1
elif in_nc != 4 and in_nc % 4 == 0:
self.resrgan_scale = 2
fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype,
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)]
LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype)
if upsample_mode == 'upconv':
upsample_block = upconv_block
elif upsample_mode == 'pixelshuffle':
upsample_block = pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)]
HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype)
HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype)
outact = act(finalact) if finalact else None
self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)),
*upsampler, HR_conv0, HR_conv1, outact)
def forward(self, x, outm=None):
if self.resrgan_scale == 1:
feat = pixel_unshuffle(x, scale=4)
elif self.resrgan_scale == 2:
feat = pixel_unshuffle(x, scale=2)
else:
feat = x
return self.model(feat)
class RRDB(nn.Module):
"""
Residual in Residual Dense Block
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
"""
def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
spectral_norm=False, gaussian_noise=False, plus=False):
super(RRDB, self).__init__()
# This is for backwards compatibility with existing models
if nr == 3:
self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus)
self.RDB2 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus)
self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus)
else:
RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type,
norm_type, act_type, mode, convtype, spectral_norm=spectral_norm,
gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)]
self.RDBs = nn.Sequential(*RDB_list)
def forward(self, x):
if hasattr(self, 'RDB1'):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
else:
out = self.RDBs(x)
return out * 0.2 + x
class ResidualDenseBlock_5C(nn.Module):
"""
Residual Dense Block
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo}
"""
def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero',
norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D',
spectral_norm=False, gaussian_noise=False, plus=False):
super(ResidualDenseBlock_5C, self).__init__()
self.noise = GaussianNoise() if gaussian_noise else None
self.conv1x1 = conv1x1(nf, gc) if plus else None
self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type,
norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype,
spectral_norm=spectral_norm)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
if self.conv1x1:
x2 = x2 + self.conv1x1(x)
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
if self.conv1x1:
x4 = x4 + x2
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
if self.noise:
return self.noise(x5.mul(0.2) + x)
else:
return x5 * 0.2 + x
####################
# ESRGANplus
####################
class GaussianNoise(nn.Module):
def __init__(self, sigma=0.1, is_relative_detach=False):
super().__init__()
self.sigma = sigma
self.is_relative_detach = is_relative_detach
self.noise = torch.tensor(0, dtype=torch.float)
def forward(self, x):
if self.training and self.sigma != 0:
self.noise = self.noise.to(x.device)
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
x = x + sampled_noise
return x
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
####################
# SRVGGNetCompact
####################
class SRVGGNetCompact(nn.Module):
"""A compact VGG-style network structure for super-resolution.
This class is copied from https://github.com/xinntao/Real-ESRGAN
"""
def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
super(SRVGGNetCompact, self).__init__()
self.num_in_ch = num_in_ch
self.num_out_ch = num_out_ch
self.num_feat = num_feat
self.num_conv = num_conv
self.upscale = upscale
self.act_type = act_type
self.body = nn.ModuleList()
# the first conv
self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
# the first activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the body structure
for _ in range(num_conv):
self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
# activation
if act_type == 'relu':
activation = nn.ReLU(inplace=True)
elif act_type == 'prelu':
activation = nn.PReLU(num_parameters=num_feat)
elif act_type == 'leakyrelu':
activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.body.append(activation)
# the last conv
self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
# upsample
self.upsampler = nn.PixelShuffle(upscale)
def forward(self, x):
out = x
for i in range(0, len(self.body)):
out = self.body[i](out)
out = self.upsampler(out)
# add the nearest upsampled image, so that the network learns the residual
base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
out += base
return out
####################
# Upsampler
####################
class Upsample(nn.Module):
r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
The input data is assumed to be of the form
`minibatch x channels x [optional depth] x [optional height] x width`.
"""
def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None):
super(Upsample, self).__init__()
if isinstance(scale_factor, tuple):
self.scale_factor = tuple(float(factor) for factor in scale_factor)
else:
self.scale_factor = float(scale_factor) if scale_factor else None
self.mode = mode
self.size = size
self.align_corners = align_corners
def forward(self, x):
return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
def extra_repr(self):
if self.scale_factor is not None:
info = 'scale_factor=' + str(self.scale_factor)
else:
info = 'size=' + str(self.size)
info += ', mode=' + self.mode
return info
def pixel_unshuffle(x, scale):
""" Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
scale (int): Downsample ratio.
Returns:
Tensor: the pixel unshuffled feature.
"""
b, c, hh, hw = x.size()
out_channel = c * (scale**2)
assert hh % scale == 0 and hw % scale == 0
h = hh // scale
w = hw // scale
x_view = x.view(b, c, h, scale, w, scale)
return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'):
""" Upconv layer """
upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor
upsample = Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype)
return sequential(upsample, conv)
####################
# Basic blocks
####################
def make_layer(basic_block, num_basic_block, **kwarg):
"""Make layers by stacking the same blocks.
Args:
basic_block (nn.module): nn.module class for basic block. (block)
num_basic_block (int): number of blocks. (n_layers)
Returns:
nn.Sequential: Stacked blocks in nn.Sequential.
"""
layers = []
for _ in range(num_basic_block):
layers.append(basic_block(**kwarg))
return nn.Sequential(*layers)
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
""" activation helper """
act_type = act_type.lower()
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type in ('leakyrelu', 'lrelu'):
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
elif act_type == 'tanh': # [-1, 1] range output
layer = nn.Tanh()
elif act_type == 'sigmoid': # [0, 1] range output
layer = nn.Sigmoid()
else:
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
return layer
class Identity(nn.Module):
def __init__(self, *kwargs):
super(Identity, self).__init__()
def forward(self, x, *kwargs):
return x
def norm(norm_type, nc):
""" Return a normalization layer """
norm_type = norm_type.lower()
if norm_type == 'batch':
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
elif norm_type == 'none':
def norm_layer(x): return Identity()
else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
return layer
def pad(pad_type, padding):
""" padding layer helper """
pad_type = pad_type.lower()
if padding == 0:
return None
if pad_type == 'reflect':
layer = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
layer = nn.ReplicationPad2d(padding)
elif pad_type == 'zero':
layer = nn.ZeroPad2d(padding)
else:
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
return layer
def get_valid_padding(kernel_size, dilation):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
class ShortcutBlock(nn.Module):
""" Elementwise sum the output of a submodule to its input """
def __init__(self, submodule):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = x + self.sub(x)
return output
def __repr__(self):
return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|')
def sequential(*args):
""" Flatten Sequential. It unwraps nn.Sequential. """
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('sequential does not support OrderedDict input.')
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
spectral_norm=False):
""" Conv layer with padding, normalization, activation """
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
if convtype=='PartialConv2D':
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='DeformConv2D':
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
elif convtype=='Conv3D':
c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
else:
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, bias=bias, groups=groups)
if spectral_norm:
c = nn.utils.spectral_norm(c)
a = act(act_type) if act_type else None
if 'CNA' in mode:
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
elif mode == 'NAC':
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)