Merge pull request #10285 from akx/ruff-spacing
Indentation + ruff whitespace fixes
commit abe32cefa3
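The hunks below show only indentation and spacing moving; the statements themselves are unchanged. As a minimal, hypothetical sketch of the kind of cleanup involved (the function and its before/after shapes are invented for illustration and do not appear in this diff):

    # hypothetical "before": mixed 2-space indentation, the kind of thing a whitespace linter flags
    def scale_value(value, factor):
      if factor > 1:
        value = value * factor
      return value

    # hypothetical "after": consistent 4-space indentation; behaviour is identical
    def scale_value(value, factor):
        if factor > 1:
            value = value * factor
        return value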
@@ -130,11 +130,11 @@ class LDSR:
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        sample = logs["sample"]
@@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
        self.instantiate_cond_stage(cond_stage_config)
        self.cond_stage_forward = cond_stage_forward
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False
        if ckpt_path is not None:
@@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )

            # 2. apply model loop over last dim
            if isinstance(self.first_stage_model, VQModelInterface):
                output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
                                                             force_not_quantize=predict_cids or force_not_quantize)
                               for i in range(z.shape[-1])]
@@ -890,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):

        if hasattr(self, "split_input_params"):
            assert len(cond) == 1 # todo can only deal with one conditioning atm
            assert not return_ids
            ks = self.split_input_params["ks"] # eg. (128, 128)
            stride = self.split_input_params["stride"] # eg. (64, 64)

@@ -265,4 +265,4 @@ class SCUNet(nn.Module):
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
@@ -150,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
            for w_idx in w_idx_list:
                if state.interrupted or state.skipped:
                    break

                in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                out_patch = model(in_patch)
                out_patch_mask = torch.ones_like(out_patch)
@@ -805,7 +805,7 @@ class SwinIR(nn.Module):
    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
            attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
        H, W = self.input_resolution
        flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        flops += H * W * self.dim // 2
        return flops

class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
            nn.init.constant_(blk.norm1.weight, 0)
            nn.init.constant_(blk.norm2.bias, 0)
            nn.init.constant_(blk.norm2.weight, 0)

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops

class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module):
                                         num_heads=num_heads,
                                         window_size=window_size,
                                         mlp_ratio=mlp_ratio,
                                         qkv_bias=qkv_bias,
                                         drop=drop, attn_drop=attn_drop,
                                         drop_path=drop_path,
                                         norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)

class Upsample_hf(nn.Sequential):
    """Upsample module.

@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super(Upsample_hf, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
        H, W = self.input_resolution
        flops = H * W * self.num_feat * 3 * 9
        return flops



class Swin2SR(nn.Module):
    r""" Swin2SR
@@ -699,7 +699,7 @@ class Swin2SR(nn.Module):

    def __init__(self, img_size=64, patch_size=1, in_chans=3,
                 embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
                         norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module):

                         )
            self.layers.append(layer)

        if self.upsampler == 'pixelshuffle_hf':
            self.layers_hf = nn.ModuleList()
            for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
                         num_heads=num_heads[i_layer],
                         window_size=window_size,
                         mlp_ratio=self.mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop=drop_rate, attn_drop=attn_drop_rate,
                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
                         norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module):

                         )
                self.layers_hf.append(layer)

        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
            self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.conv_after_aux = nn.Sequential(
                nn.Conv2d(3, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffle_hf':
            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                                                      nn.LeakyReLU(inplace=True))
@@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
                nn.LeakyReLU(inplace=True))
            self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
        x = self.patch_unembed(x, x_size)

        return x

    def forward_features_hf(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
@@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
        x = self.norm(x) # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        H, W = x.shape[2:]
@@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
            x = self.conv_after_body(self.forward_features(x)) + x
            x_before = self.conv_before_upsample(x)
            x_out = self.conv_last(self.upsample(x_before))

            x_hf = self.conv_first_hf(x_before)
            x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
            x_hf = self.conv_before_upsample_hf(x_hf)
@@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean
        if self.upsampler == "pixelshuffle_aux":
            return x[:, :, :H*self.upscale, :W*self.upscale], aux

        elif self.upsampler == "pixelshuffle_hf":
            x_out = x_out / self.img_range + self.mean
            return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]

        else:
            return x[:, :, :H*self.upscale, :W*self.upscale]

@@ -1014,4 +1014,4 @@ if __name__ == '__main__':

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
@@ -327,7 +327,7 @@ def prepare_environment():

    if args.update_all_extensions:
        git_pull_recursive(extensions_dir)

    if "--exit" in sys.argv:
        print("Exiting because of --exit argument")
        exit(0)
@@ -227,7 +227,7 @@ class Api:
        script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
        script = script_runner.selectable_scripts[script_idx]
        return script, script_idx

    def get_scripts_list(self):
        t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
        i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
@@ -237,7 +237,7 @@ class Api:
    def get_script(self, script_name, script_runner):
        if script_name is None or script_name == "":
            return None, None

        script_idx = script_name_to_index(script_name, script_runner.scripts)
        return script_runner.scripts[script_idx]

@@ -289,4 +289,4 @@ class MemoryResponse(BaseModel):

class ScriptsList(BaseModel):
    txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
    img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
@@ -102,4 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
@@ -119,7 +119,7 @@ class TransformerSALayer(nn.Module):
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):

        # self attention
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
@@ -159,7 +159,7 @@ class Fuse_sft_block(nn.Module):

@ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder):
    def __init__(self, dim_embd=512, n_head=8, n_layers=9,
                 codebook_size=1024, latent_size=256,
                 connect_list=('32', '64', '128', '256'),
                 fix_modules=('quantize', 'generator')):
@@ -179,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
        self.feat_emb = nn.Linear(256, self.dim_embd)

        # transformer
        self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
                                         for _ in range(self.n_layers)])

        # logits_predict head
        self.idx_pred_layer = nn.Sequential(
            nn.LayerNorm(dim_embd),
            nn.Linear(dim_embd, codebook_size, bias=False))

        self.channels = {
            '16': 512,
            '32': 256,
@@ -221,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
        enc_feat_dict = {}
        out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
        for i, block in enumerate(self.encoder.blocks):
            x = block(x)
            if i in out_list:
                enc_feat_dict[str(x.shape[-1])] = x.clone()

@@ -266,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
        fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]

        for i, block in enumerate(self.generator.blocks):
            x = block(x)
            if i in fuse_list: # fuse after i-th block
                f_size = str(x.shape[-1])
                if w>0:
                    x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
        out = x
        # logits doesn't need softmax before cross_entropy loss
        return out, logits, lq_feat
@@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY

def normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


@torch.jit.script
def swish(x):
@@ -210,15 +210,15 @@ class AttnBlock(nn.Module):
        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h*w)
        q = q.permute(0, 2, 1)
        k = k.reshape(b, c, h*w)
        w_ = torch.bmm(q, k)
        w_ = w_ * (int(c)**(-0.5))
        w_ = F.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h*w)
        w_ = w_.permute(0, 2, 1)
        h_ = torch.bmm(v, w_)
        h_ = h_.reshape(b, c, h, w)

@@ -270,18 +270,18 @@ class Encoder(nn.Module):
    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


class Generator(nn.Module):
    def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
        super().__init__()
        self.nf = nf
        self.ch_mult = ch_mult
        self.num_resolutions = len(self.ch_mult)
        self.num_res_blocks = res_blocks
        self.resolution = img_size
        self.attn_resolutions = attn_resolutions
        self.in_channels = emb_dim
        self.out_channels = 3
@@ -315,24 +315,24 @@ class Generator(nn.Module):
        blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))

        self.blocks = nn.ModuleList(blocks)


    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        return x


@ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module):
    def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
                 beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
        super().__init__()
        logger = get_root_logger()
        self.in_channels = 3
        self.nf = nf
        self.n_blocks = res_blocks
        self.codebook_size = codebook_size
        self.embed_dim = emb_dim
        self.ch_mult = ch_mult
@@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module):
                self.kl_weight
            )
        self.generator = Generator(
            self.nf,
            self.embed_dim,
            self.ch_mult,
            self.n_blocks,
            self.resolution,
            self.attn_resolutions
        )

@@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
            raise ValueError('Wrong params!')

    def forward(self, x):
        return self.main(x)
@@ -105,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

@@ -170,7 +170,7 @@ class GaussianNoise(nn.Module):
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
@@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
                result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

            theta_0[key] = to_half(theta_0[key], save_as_half)

            shared.state.sampling_step += 1
@@ -540,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@@ -593,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
        print(e)

    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
@@ -636,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi

                if clip_grad:
                    clip_grad_sched.step(hypernetwork.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if use_weight:
@@ -657,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi

                _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                loss_logging.append(_loss_step)
                if clip_grad:
                    clip_grad(weights, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
@@ -674,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                _loss_step = 0

                steps_done = hypernetwork.step + 1

                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

@@ -367,7 +367,7 @@ class FilenameGenerator:
        self.seed = seed
        self.prompt = prompt
        self.image = image

    def hasprompt(self, *args):
        lower = self.prompt.lower()
        if self.p is None or self.prompt is None:
@@ -42,7 +42,7 @@ if has_mps:
        # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
        CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
            lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
        # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
            lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
        # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@@ -60,4 +60,4 @@ if has_mps:
    # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
    if platform.processor() == 'i386':
        for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
            CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
@@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0):
    """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
    For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""

    h, w = mask.shape

    crop_left = 0
@@ -13,7 +13,7 @@ def connect(token, port, region):
    config = conf.PyngrokConfig(
        auth_token=token, region=region
    )

    # Guard for existing tunnels
    existing = ngrok.get_tunnels(pyngrok_config=config)
    if existing:
@@ -24,7 +24,7 @@ def connect(token, port, region):
                print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
                      'You can use this link after the launch is complete.')
                return

    try:
        if account is None:
            public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
@@ -164,7 +164,7 @@ class StableDiffusionProcessing:
        self.all_subseeds = None
        self.iteration = 0
        self.is_hr_pass = False


    @property
    def sd_model(self):
@@ -32,22 +32,22 @@ class CFGDenoiserParams:
    def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
        self.x = x
        """Latent image representation in the process of being denoised"""

        self.image_cond = image_cond
        """Conditioning image"""

        self.sigma = sigma
        """Current sigma noise step value"""

        self.sampling_step = sampling_step
        """Current Sampling step number"""

        self.total_sampling_steps = total_sampling_steps
        """Total number of sampling steps planned"""

        self.text_cond = text_cond
        """ Encoder hidden states of text conditioning from prompt"""

        self.text_uncond = text_uncond
        """ Encoder hidden states of text conditioning from negative prompt"""

@@ -240,7 +240,7 @@ def add_callback(callbacks, fun):

    callbacks.append(ScriptCallback(filename, fun))


def remove_current_script_callbacks():
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -34,7 +34,7 @@ def apply_optimizations():

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None

    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
@@ -92,12 +92,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True):
    #Calculate the weight normally, but ignore the mean
    loss = sd_model._old_get_loss(pred, target, mean=False)

    #Check if we have weights available
    weight = getattr(sd_model, '_custom_loss_weight', None)
    if weight is not None:
        loss *= weight

    #Return the loss, as mean if specified
    return loss.mean() if mean else loss

@@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
    try:
        #Temporarily append weights to a place accessible during loss calc
        sd_model._custom_loss_weight = w

        #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
        #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
        if not hasattr(sd_model, '_old_get_loss'):
@@ -120,7 +120,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
            del sd_model._custom_loss_weight
        except AttributeError:
            pass

        #If we have an old loss function, reset the loss function to the original one
        if hasattr(sd_model, '_old_get_loss'):
            sd_model.get_loss = sd_model._old_get_loss
@@ -184,7 +184,7 @@ class StableDiffusionModelHijack:

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped
@@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
        end = i + 2
        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
        s1 *= self.scale

        s2 = s1.softmax(dim=-1)
        del s1

        r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
        del s2
    del q, k, v
@@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k_in = k_in * self.scale

        del context, x

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
        del q_in, k_in, v_in

        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

        mem_free_total = get_available_vram()

        gb = 1024 ** 3
        tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
        modifier = 3 if q.element_size() == 2 else 2.5
        mem_required = tensor_size * modifier
        steps = 1

        if mem_required > mem_free_total:
            steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
            # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)
@@ -228,7 +228,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)
@@ -369,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)

    del q_in, k_in, v_in

    dtype = q.dtype
@@ -451,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
    h3 += x

    return h3

def xformers_attnblock_forward(self, x):
    try:
        h_ = x
@@ -165,7 +165,7 @@ def model_hash(filename):

def select_checkpoint():
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info
@@ -372,7 +372,7 @@ def enable_midas_autodownload():
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")
@@ -93,10 +93,10 @@ class CFGDenoiser(torch.nn.Module):

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
        else:
            image_uncond = image_cond
            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}

        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -316,7 +316,7 @@ class KDiffusionSampler:

        sigma_sched = sigmas[steps - t_enc - 1:]
        xi = x + noise * sigma_sched[0]

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

@@ -339,9 +339,9 @@ class KDiffusionSampler:
        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        extra_args={
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }
@@ -374,9 +374,9 @@ class KDiffusionSampler:

        self.last_latent = x
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
            chunk_idx,
            min(query_chunk_size, q_tokens)
        )

    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
    summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
    compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@ -10,63 +10,64 @@ RED = "#F00"


def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """

scale_by = 1
if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height

im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()

focus = focal_point(im_debug, settings)

# take the focal point and turn it into crop coordinates that try to center over the focal
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)

x1 = focus.x - x_half
if x1 < 0:
x1 = 0
elif x1 + settings.crop_width > im.width:
x1 = im.width - settings.crop_width

y1 = focus.y - y_half
if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height

x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height

crop = [x1, y1, x2, y2]

results = []

results.append(im.crop(tuple(crop)))

if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()

return results


def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@ -86,7 +87,7 @@ def focal_point(im, settings):
corner_centroid = None
if len(corner_points) > 0:
corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid)

entropy_centroid = None
@ -98,7 +99,7 @@ def focal_point(im, settings):
face_centroid = None
if len(face_points) > 0:
face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid)

average_point = poi_average(pois, settings)
@ -132,7 +133,7 @@ def focal_point(im, settings):
d.rectangle(f.bounding(4), outline=color)

d.ellipse(average_point.bounding(max_size), outline=GREEN)

return average_point


@ -260,10 +261,11 @@ def image_entropy(im):
hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum()


def centroid(pois):
x = [poi.x for poi in pois]
y = [poi.y for poi in pois]
-return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois))
+return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))


def poi_average(pois, settings):
@ -281,59 +283,59 @@ def poi_average(pois, settings):


def is_landscape(w, h):
return w > h


def is_portrait(w, h):
return h > w


def is_square(w, h):
return w == h


def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx'

if not os.path.exists(dirname):
os.makedirs(dirname)

cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url)
with open(cache_file, "wb") as f:
f.write(response.content)

if os.path.exists(cache_file):
return cache_file
return None


class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10):
self.x = x
self.y = y
self.weight = weight
self.size = size

def bounding(self, size):
return [
-self.x - size//2,
-self.y - size//2,
-self.x + size//2,
-self.y + size//2
+self.x - size // 2,
+self.y - size // 2,
+self.x + size // 2,
+self.y + size // 2
]


class Settings:
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width
self.crop_height = crop_height
self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
self.dnn_model_path = dnn_model_path
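Taken together, the autocrop hunks above define `Settings`, `PointOfInterest`, the focal-point weighting helpers and `crop_image`. A rough usage sketch, assuming the import path and file names shown below (they are illustrative, not part of the diff):

from PIL import Image
from modules.textual_inversion.autocrop import crop_image, download_and_cache_models, Settings  # assumed import path

model_path = download_and_cache_models("models/opencv")  # hypothetical cache directory
settings = Settings(crop_width=512, crop_height=512,
                    corner_points_weight=0.5, entropy_points_weight=0.5,
                    face_points_weight=0.5, annotate_image=True,
                    dnn_model_path=model_path)

im = Image.open("example.png")  # hypothetical input image
cropped, debug_view = crop_image(im, settings)  # with annotate_image=True the second result is the annotated debug copy
cropped.save("example_cropped.png")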
@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
weight = torch.ones(latent_sample.shape)
else:
weight = None

if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else:
@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
return self

def collate_wrapper_random(batch):
return BatchLoaderRandom(batch)
@ -125,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
default=None
)
return wh and center_crop(image, *wh)


def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width
@ -323,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)

def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step)

def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1))

tensorboard_writer.add_image(tag, img_tensor, global_step=step)

def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@ -402,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename

scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@ -412,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed

if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)

@ -439,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint")
@ -485,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st

if clip_grad:
clip_grad_sched.step(embedding.step)

with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@ -513,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue

if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate)
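The tensorboard helpers in the first hunk above are thin wrappers over `torch.utils.tensorboard`; a brief, hedged sketch of feeding them directly (the import path and log directory are assumptions for the example):

from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from modules.textual_inversion.textual_inversion import tensorboard_add_scaler, tensorboard_add_image  # assumed import path

writer = SummaryWriter(log_dir="logs/embedding-run")  # hypothetical log directory
tensorboard_add_scaler(writer, "Loss/train", 0.123, 100)
tensorboard_add_image(writer, "Validation image", Image.new("RGB", (64, 64)), 100)
writer.flush()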
@ -1171,7 +1171,7 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")

with gr.Column(visible=False) as process_multicrop_col:
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
with gr.Row():
@ -1183,7 +1183,7 @@ def create_ui():
with gr.Row():
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")

with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
@ -1226,7 +1226,7 @@ def create_ui():
with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")

with FormRow():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
@ -1565,7 +1565,7 @@ def create_ui():
gr.HTML(shared.html("licenses.html"), elem_id="licenses")

gr.Button(value="Show all pages", elem_id="settings_show_all_pages")


def unload_sd_weights():
modules.sd_models.unload_model_weights()
@ -1841,15 +1841,15 @@ def versions_html():

return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
•
python: <span title="{sys.version}">{python_version}</span>
•
torch: {getattr(torch, '__long_version__',torch.__version__)}
•
xformers: {xformers_version}
•
gradio: {gr.__version__}
•
checkpoint: <a id="sd_checkpoint_hash">N/A</a>
"""

@ -467,7 +467,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
<td>{install_code}</td>
</tr>

"""

for tag in [x for x in extension_tags if x not in tags]:
@ -535,9 +535,9 @@ def create_ui():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")

with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False)

install_result = gr.HTML()
available_extensions_table = gr.HTML()

@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
config_class = BertSeriesConfig

def __init__(self, config=None, **kargs):
# modify initialization for autoloading
if config is None:
config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1
@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device)
features = self(**text)
return features['projection_state']

def forward(
self,
@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):

class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig
@ -6,6 +6,7 @@ extend-select = [
"B",
"C",
"I",
+"W",
]

exclude = [
@ -20,7 +21,7 @@ ignore = [
"I001", # Import block is un-sorted or un-formatted
"C901", # Function is too complex
"C408", # Rewrite as a literal
+"W605", # invalid escape sequence, messes with some docstrings
]

[tool.ruff.per-file-ignores]
@ -28,4 +29,4 @@ ignore = [

[tool.ruff.flake8-bugbear]
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]
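The pyproject.toml hunks above enable ruff's pycodestyle warnings ("W") while keeping "W605" (invalid escape sequence) ignored, since backslash escapes such as \d show up in docstrings across the codebase. A tiny illustrative snippet of the pattern W605 flags (not taken from the repository):

import re

def find_digits(text):
    """Match \d+ runs - the bare \d in this non-raw docstring is what W605 reports."""
    return re.findall(r"\d+", text)  # the raw string on this line is fine

print(find_digits("a1b22c333"))  # ['1', '22', '333']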
@ -149,9 +149,9 @@ class Script(scripts.Script):
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))

return [
info,
override_sampler,
override_prompt, original_prompt, original_negative_prompt,
override_steps, st,
override_strength,
cfg, randomness, sigma_adjustment,
@ -191,17 +191,17 @@ class Script(scripts.Script):
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)

rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

sigmas = sampler.model_wrap.get_sigmas(p.steps)

noise_dt = combined_noise - (p.init_latent / sigmas[0])

p.seed = p.seed + 1

return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

p.sample = sample_extra
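One detail in the hunk above worth spelling out: `combined_noise` divides the weighted mix by `(randomness**2 + (1 - randomness)**2) ** 0.5` because, for two independent and roughly unit-variance noise tensors, the variance of `a*x + b*y` is `a**2 + b**2`, so the division restores roughly unit variance. A quick numpy check of that reasoning (illustrative only, not part of the script):

import numpy as np

rng = np.random.default_rng(0)
rec_noise = rng.standard_normal(1_000_000)
rand_noise = rng.standard_normal(1_000_000)
randomness = 0.3

combined = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1 - randomness)**2) ** 0.5)

print(round(combined.std(), 3))  # ~1.0: the normalization keeps unit standard deviation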
@ -14,7 +14,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img

def ui(self, is_img2img):
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
@ -104,7 +104,7 @@ class Script(scripts.Script):

p.seed = processed.seed + 1
p.denoising_strength = calculate_denoising_strength(i + 1)

if state.skipped:
break

@ -121,7 +121,7 @@ class Script(scripts.Script):
all_images.append(last_image)

p.inpainting_fill = original_inpainting_fill

if state.interrupted:
break

@ -132,7 +132,7 @@ class Script(scripts.Script):

if opts.return_grid:
grids.append(grid)

all_images = grids + all_images

processed = Processed(p, all_images, initial_seed, initial_info)
@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img):
if not is_img2img:
return None

pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))
@ -96,7 +96,7 @@ class Script(scripts.Script):
p.prompt_for_display = positive_prompt
processed = process_images(p)

grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid)
processed.index_of_first_image = 1
@ -109,7 +109,7 @@ class Script(scripts.Script):
def title(self):
return "Prompts from file or textbox"

def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))

@ -166,7 +166,7 @@ class Script(scripts.Script):

proc = process_images(copy_p)
images += proc.images

if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter)
all_prompts += proc.all_prompts
@ -16,7 +16,7 @@ class Script(scripts.Script):
def show(self, is_img2img):
return is_img2img

def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))
@ -1,62 +1,64 @@
import unittest
import requests


class UtilsTests(unittest.TestCase):
def setUp(self):
self.url_options = "http://localhost:7860/sdapi/v1/options"
self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"

def test_options_get(self):
self.assertEqual(requests.get(self.url_options).status_code, 200)

def test_options_write(self):
response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200)

pre_value = response.json()["send_seed"]

-self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200)
+self.assertEqual(requests.post(self.url_options, json={"send_seed": not pre_value}).status_code, 200)

response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["send_seed"], not pre_value)

requests.post(self.url_options, json={"send_seed": pre_value})

def test_cmd_flags(self):
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)

def test_samplers(self):
self.assertEqual(requests.get(self.url_samplers).status_code, 200)

def test_upscalers(self):
self.assertEqual(requests.get(self.url_upscalers).status_code, 200)

def test_sd_models(self):
self.assertEqual(requests.get(self.url_sd_models).status_code, 200)

def test_hypernetworks(self):
self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)

def test_face_restorers(self):
self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)

def test_realesrgan_models(self):
self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)

def test_prompt_styles(self):
self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)

def test_embeddings(self):
self.assertEqual(requests.get(self.url_embeddings).status_code, 200)


if __name__ == "__main__":
unittest.main()