Merge pull request #10285 from akx/ruff-spacing

Indentation + ruff whitespace fixes
This commit is contained in:
AUTOMATIC1111 2023-05-11 21:25:15 +03:00 committed by GitHub
commit abe32cefa3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 301 additions and 296 deletions

View File

@ -130,11 +130,11 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else: else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta) logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"] sample = logs["sample"]

View File

@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
self.instantiate_cond_stage(cond_stage_config) self.instantiate_cond_stage(cond_stage_config)
self.cond_stage_forward = cond_stage_forward self.cond_stage_forward = cond_stage_forward
self.clip_denoised = False self.clip_denoised = False
self.bbox_tokenizer = None self.bbox_tokenizer = None
self.restarted_from_ckpt = False self.restarted_from_ckpt = False
if ckpt_path is not None: if ckpt_path is not None:
@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
# 2. apply model loop over last dim # 2. apply model loop over last dim
if isinstance(self.first_stage_model, VQModelInterface): if isinstance(self.first_stage_model, VQModelInterface):
output_list = [self.first_stage_model.decode(z[:, :, :, :, i], output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
force_not_quantize=predict_cids or force_not_quantize) force_not_quantize=predict_cids or force_not_quantize)
for i in range(z.shape[-1])] for i in range(z.shape[-1])]
@ -890,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
if hasattr(self, "split_input_params"): if hasattr(self, "split_input_params"):
assert len(cond) == 1 # todo can only deal with one conditioning atm assert len(cond) == 1 # todo can only deal with one conditioning atm
assert not return_ids assert not return_ids
ks = self.split_input_params["ks"] # eg. (128, 128) ks = self.split_input_params["ks"] # eg. (128, 128)
stride = self.split_input_params["stride"] # eg. (64, 64) stride = self.split_input_params["stride"] # eg. (64, 64)

View File

@ -265,4 +265,4 @@ class SCUNet(nn.Module):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm): elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)

View File

@ -150,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for w_idx in w_idx_list: for w_idx in w_idx_list:
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
break break
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch) out_patch_mask = torch.ones_like(out_patch)

View File

@ -805,7 +805,7 @@ class SwinIR(nn.Module):
def forward(self, x): def forward(self, x):
H, W = x.shape[2:] H, W = x.shape[2:]
x = self.check_image_size(x) x = self.check_image_size(x)
self.mean = self.mean.type_as(x) self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range x = (x - self.mean) * self.img_range

View File

@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None attn_mask = None
self.register_buffer("attn_mask", attn_mask) self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size): def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA # calculate attention mask for SW-MSA
H, W = x_size H, W = x_size
@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask return attn_mask
def forward(self, x, x_size): def forward(self, x, x_size):
H, W = x_size H, W = x_size
@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else: else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows # merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
H, W = self.input_resolution H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2 flops += H * W * self.dim // 2
return flops return flops
class BasicLayer(nn.Module): class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage. """ A basic Swin Transformer layer for one stage.
@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
nn.init.constant_(blk.norm1.weight, 0) nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0) nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0) nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
r""" Image to Patch Embedding r""" Image to Patch Embedding
Args: Args:
@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None: if self.norm is not None:
flops += Ho * Wo * self.embed_dim flops += Ho * Wo * self.embed_dim
return flops return flops
class RSTB(nn.Module): class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB). """Residual Swin Transformer Block (RSTB).
@ -531,7 +531,7 @@ class RSTB(nn.Module):
num_heads=num_heads, num_heads=num_heads,
window_size=window_size, window_size=window_size,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop=drop, attn_drop=attn_drop,
drop_path=drop_path, drop_path=drop_path,
norm_layer=norm_layer, norm_layer=norm_layer,
@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
else: else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m) super(Upsample, self).__init__(*m)
class Upsample_hf(nn.Sequential): class Upsample_hf(nn.Sequential):
"""Upsample module. """Upsample module.
@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
m.append(nn.PixelShuffle(3)) m.append(nn.PixelShuffle(3))
else: else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample_hf, self).__init__(*m) super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential): class UpsampleOneStep(nn.Sequential):
@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
H, W = self.input_resolution H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9 flops = H * W * self.num_feat * 3 * 9
return flops return flops
class Swin2SR(nn.Module): class Swin2SR(nn.Module):
r""" Swin2SR r""" Swin2SR
@ -699,7 +699,7 @@ class Swin2SR(nn.Module):
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer, norm_layer=norm_layer,
@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
) )
self.layers.append(layer) self.layers.append(layer)
if self.upsampler == 'pixelshuffle_hf': if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList() self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers): for i_layer in range(self.num_layers):
@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer, norm_layer=norm_layer,
@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
) )
self.layers_hf.append(layer) self.layers_hf.append(layer)
self.norm = norm_layer(self.num_features) self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction # build the last conv layer in deep feature extraction
@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential( self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1), nn.Conv2d(3, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat) self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffle_hf': elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True)) nn.LeakyReLU(inplace=True))
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect': elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters) # for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
x = self.patch_unembed(x, x_size) x = self.patch_unembed(x, x_size)
return x return x
def forward_features_hf(self, x): def forward_features_hf(self, x):
x_size = (x.shape[2], x.shape[3]) x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x) x = self.patch_embed(x)
@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
x = self.norm(x) # B L C x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size) x = self.patch_unembed(x, x_size)
return x return x
def forward(self, x): def forward(self, x):
H, W = x.shape[2:] H, W = x.shape[2:]
@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
x = self.conv_after_body(self.forward_features(x)) + x x = self.conv_after_body(self.forward_features(x)) + x
x_before = self.conv_before_upsample(x) x_before = self.conv_before_upsample(x)
x_out = self.conv_last(self.upsample(x_before)) x_out = self.conv_last(self.upsample(x_before))
x_hf = self.conv_first_hf(x_before) x_hf = self.conv_first_hf(x_before)
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
x_hf = self.conv_before_upsample_hf(x_hf) x_hf = self.conv_before_upsample_hf(x_hf)
@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
x_first = self.conv_first(x) x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res) x = x + self.conv_last(res)
x = x / self.img_range + self.mean x = x / self.img_range + self.mean
if self.upsampler == "pixelshuffle_aux": if self.upsampler == "pixelshuffle_aux":
return x[:, :, :H*self.upscale, :W*self.upscale], aux return x[:, :, :H*self.upscale, :W*self.upscale], aux
elif self.upsampler == "pixelshuffle_hf": elif self.upsampler == "pixelshuffle_hf":
x_out = x_out / self.img_range + self.mean x_out = x_out / self.img_range + self.mean
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
else: else:
return x[:, :, :H*self.upscale, :W*self.upscale] return x[:, :, :H*self.upscale, :W*self.upscale]
@ -1014,4 +1014,4 @@ if __name__ == '__main__':
x = torch.randn((1, 3, height, width)) x = torch.randn((1, 3, height, width))
x = model(x) x = model(x)
print(x.shape) print(x.shape)

View File

@ -327,7 +327,7 @@ def prepare_environment():
if args.update_all_extensions: if args.update_all_extensions:
git_pull_recursive(extensions_dir) git_pull_recursive(extensions_dir)
if "--exit" in sys.argv: if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
exit(0) exit(0)

View File

@ -227,7 +227,7 @@ class Api:
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
script = script_runner.selectable_scripts[script_idx] script = script_runner.selectable_scripts[script_idx]
return script, script_idx return script, script_idx
def get_scripts_list(self): def get_scripts_list(self):
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles] t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles] i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
@ -237,7 +237,7 @@ class Api:
def get_script(self, script_name, script_runner): def get_script(self, script_name, script_runner):
if script_name is None or script_name == "": if script_name is None or script_name == "":
return None, None return None, None
script_idx = script_name_to_index(script_name, script_runner.scripts) script_idx = script_name_to_index(script_name, script_runner.scripts)
return script_runner.scripts[script_idx] return script_runner.scripts[script_idx]

View File

@ -289,4 +289,4 @@ class MemoryResponse(BaseModel):
class ScriptsList(BaseModel): class ScriptsList(BaseModel):
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)") txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)") img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")

View File

@ -102,4 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers") parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False) parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')

View File

@ -119,7 +119,7 @@ class TransformerSALayer(nn.Module):
tgt_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None): query_pos: Optional[Tensor] = None):
# self attention # self attention
tgt2 = self.norm1(tgt) tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos) q = k = self.with_pos_embed(tgt2, query_pos)
@ -159,7 +159,7 @@ class Fuse_sft_block(nn.Module):
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class CodeFormer(VQAutoEncoder): class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9, def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256, codebook_size=1024, latent_size=256,
connect_list=('32', '64', '128', '256'), connect_list=('32', '64', '128', '256'),
fix_modules=('quantize', 'generator')): fix_modules=('quantize', 'generator')):
@ -179,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
self.feat_emb = nn.Linear(256, self.dim_embd) self.feat_emb = nn.Linear(256, self.dim_embd)
# transformer # transformer
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
for _ in range(self.n_layers)]) for _ in range(self.n_layers)])
# logits_predict head # logits_predict head
self.idx_pred_layer = nn.Sequential( self.idx_pred_layer = nn.Sequential(
nn.LayerNorm(dim_embd), nn.LayerNorm(dim_embd),
nn.Linear(dim_embd, codebook_size, bias=False)) nn.Linear(dim_embd, codebook_size, bias=False))
self.channels = { self.channels = {
'16': 512, '16': 512,
'32': 256, '32': 256,
@ -221,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
enc_feat_dict = {} enc_feat_dict = {}
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.encoder.blocks): for i, block in enumerate(self.encoder.blocks):
x = block(x) x = block(x)
if i in out_list: if i in out_list:
enc_feat_dict[str(x.shape[-1])] = x.clone() enc_feat_dict[str(x.shape[-1])] = x.clone()
@ -266,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
for i, block in enumerate(self.generator.blocks): for i, block in enumerate(self.generator.blocks):
x = block(x) x = block(x)
if i in fuse_list: # fuse after i-th block if i in fuse_list: # fuse after i-th block
f_size = str(x.shape[-1]) f_size = str(x.shape[-1])
if w>0: if w>0:
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x out = x
# logits doesn't need softmax before cross_entropy loss # logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat return out, logits, lq_feat

View File

@ -13,7 +13,7 @@ from basicsr.utils.registry import ARCH_REGISTRY
def normalize(in_channels): def normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
@torch.jit.script @torch.jit.script
def swish(x): def swish(x):
@ -210,15 +210,15 @@ class AttnBlock(nn.Module):
# compute attention # compute attention
b, c, h, w = q.shape b, c, h, w = q.shape
q = q.reshape(b, c, h*w) q = q.reshape(b, c, h*w)
q = q.permute(0, 2, 1) q = q.permute(0, 2, 1)
k = k.reshape(b, c, h*w) k = k.reshape(b, c, h*w)
w_ = torch.bmm(q, k) w_ = torch.bmm(q, k)
w_ = w_ * (int(c)**(-0.5)) w_ = w_ * (int(c)**(-0.5))
w_ = F.softmax(w_, dim=2) w_ = F.softmax(w_, dim=2)
# attend to values # attend to values
v = v.reshape(b, c, h*w) v = v.reshape(b, c, h*w)
w_ = w_.permute(0, 2, 1) w_ = w_.permute(0, 2, 1)
h_ = torch.bmm(v, w_) h_ = torch.bmm(v, w_)
h_ = h_.reshape(b, c, h, w) h_ = h_.reshape(b, c, h, w)
@ -270,18 +270,18 @@ class Encoder(nn.Module):
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:
x = block(x) x = block(x)
return x return x
class Generator(nn.Module): class Generator(nn.Module):
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
super().__init__() super().__init__()
self.nf = nf self.nf = nf
self.ch_mult = ch_mult self.ch_mult = ch_mult
self.num_resolutions = len(self.ch_mult) self.num_resolutions = len(self.ch_mult)
self.num_res_blocks = res_blocks self.num_res_blocks = res_blocks
self.resolution = img_size self.resolution = img_size
self.attn_resolutions = attn_resolutions self.attn_resolutions = attn_resolutions
self.in_channels = emb_dim self.in_channels = emb_dim
self.out_channels = 3 self.out_channels = 3
@ -315,24 +315,24 @@ class Generator(nn.Module):
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
self.blocks = nn.ModuleList(blocks) self.blocks = nn.ModuleList(blocks)
def forward(self, x): def forward(self, x):
for block in self.blocks: for block in self.blocks:
x = block(x) x = block(x)
return x return x
@ARCH_REGISTRY.register() @ARCH_REGISTRY.register()
class VQAutoEncoder(nn.Module): class VQAutoEncoder(nn.Module):
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
super().__init__() super().__init__()
logger = get_root_logger() logger = get_root_logger()
self.in_channels = 3 self.in_channels = 3
self.nf = nf self.nf = nf
self.n_blocks = res_blocks self.n_blocks = res_blocks
self.codebook_size = codebook_size self.codebook_size = codebook_size
self.embed_dim = emb_dim self.embed_dim = emb_dim
self.ch_mult = ch_mult self.ch_mult = ch_mult
@ -363,11 +363,11 @@ class VQAutoEncoder(nn.Module):
self.kl_weight self.kl_weight
) )
self.generator = Generator( self.generator = Generator(
self.nf, self.nf,
self.embed_dim, self.embed_dim,
self.ch_mult, self.ch_mult,
self.n_blocks, self.n_blocks,
self.resolution, self.resolution,
self.attn_resolutions self.attn_resolutions
) )
@ -432,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
raise ValueError('Wrong params!') raise ValueError('Wrong params!')
def forward(self, x): def forward(self, x):
return self.main(x) return self.main(x)

View File

@ -105,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
Modified options that can be used: Modified options that can be used:
- "Partial Convolution based Padding" arXiv:1811.11718 - "Partial Convolution based Padding" arXiv:1811.11718
- "Spectral normalization" arXiv:1802.05957 - "Spectral normalization" arXiv:1802.05957
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
{Rakotonirina} and A. {Rasoanaivo} {Rakotonirina} and A. {Rasoanaivo}
""" """
@ -170,7 +170,7 @@ class GaussianNoise(nn.Module):
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
x = x + sampled_noise x = x + sampled_noise
return x return x
def conv1x1(in_planes, out_planes, stride=1): def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

View File

@ -199,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
result_is_inpainting_model = True result_is_inpainting_model = True
else: else:
theta_0[key] = theta_func2(a, b, multiplier) theta_0[key] = theta_func2(a, b, multiplier)
theta_0[key] = to_half(theta_0[key], save_as_half) theta_0[key] = to_half(theta_0[key], save_as_half)
shared.state.sampling_step += 1 shared.state.sampling_step += 1

View File

@ -540,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
if clip_grad: if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
@ -593,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
print(e) print(e)
scaler = torch.cuda.amp.GradScaler() scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size batch_size = ds.batch_size
gradient_step = ds.gradient_step gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed # n steps = batch_size * gradient_step * n image processed
@ -636,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if clip_grad: if clip_grad:
clip_grad_sched.step(hypernetwork.step) clip_grad_sched.step(hypernetwork.step)
with devices.autocast(): with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight: if use_weight:
@ -657,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step += loss.item() _loss_step += loss.item()
scaler.scale(loss).backward() scaler.scale(loss).backward()
# go back until we reach gradient accumulation steps # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0: if (j + 1) % gradient_step != 0:
continue continue
loss_logging.append(_loss_step) loss_logging.append(_loss_step)
if clip_grad: if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate) clip_grad(weights, clip_grad_sched.learn_rate)
scaler.step(optimizer) scaler.step(optimizer)
scaler.update() scaler.update()
hypernetwork.step += 1 hypernetwork.step += 1
@ -674,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0 _loss_step = 0
steps_done = hypernetwork.step + 1 steps_done = hypernetwork.step + 1
epoch_num = hypernetwork.step // steps_per_epoch epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch

View File

@ -367,7 +367,7 @@ class FilenameGenerator:
self.seed = seed self.seed = seed
self.prompt = prompt self.prompt = prompt
self.image = image self.image = image
def hasprompt(self, *args): def hasprompt(self, *args):
lower = self.prompt.lower() lower = self.prompt.lower()
if self.p is None or self.prompt is None: if self.p is None or self.prompt is None:

View File

@ -42,7 +42,7 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 # MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 # MPS workaround for https://github.com/pytorch/pytorch/issues/90532
@ -60,4 +60,4 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311 # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386': if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']: for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')

View File

@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
def get_crop_region(mask, pad=0): def get_crop_region(mask, pad=0):
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle. """finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)""" For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
h, w = mask.shape h, w = mask.shape
crop_left = 0 crop_left = 0

View File

@ -13,7 +13,7 @@ def connect(token, port, region):
config = conf.PyngrokConfig( config = conf.PyngrokConfig(
auth_token=token, region=region auth_token=token, region=region
) )
# Guard for existing tunnels # Guard for existing tunnels
existing = ngrok.get_tunnels(pyngrok_config=config) existing = ngrok.get_tunnels(pyngrok_config=config)
if existing: if existing:
@ -24,7 +24,7 @@ def connect(token, port, region):
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n' print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
'You can use this link after the launch is complete.') 'You can use this link after the launch is complete.')
return return
try: try:
if account is None: if account is None:
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url

View File

@ -164,7 +164,7 @@ class StableDiffusionProcessing:
self.all_subseeds = None self.all_subseeds = None
self.iteration = 0 self.iteration = 0
self.is_hr_pass = False self.is_hr_pass = False
@property @property
def sd_model(self): def sd_model(self):

View File

@ -32,22 +32,22 @@ class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
self.x = x self.x = x
"""Latent image representation in the process of being denoised""" """Latent image representation in the process of being denoised"""
self.image_cond = image_cond self.image_cond = image_cond
"""Conditioning image""" """Conditioning image"""
self.sigma = sigma self.sigma = sigma
"""Current sigma noise step value""" """Current sigma noise step value"""
self.sampling_step = sampling_step self.sampling_step = sampling_step
"""Current Sampling step number""" """Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned""" """Total number of sampling steps planned"""
self.text_cond = text_cond self.text_cond = text_cond
""" Encoder hidden states of text conditioning from prompt""" """ Encoder hidden states of text conditioning from prompt"""
self.text_uncond = text_uncond self.text_uncond = text_uncond
""" Encoder hidden states of text conditioning from negative prompt""" """ Encoder hidden states of text conditioning from negative prompt"""
@ -240,7 +240,7 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks(): def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if len(stack) > 0 else 'unknown file'

View File

@ -34,7 +34,7 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
optimization_method = None optimization_method = None
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
@ -92,12 +92,12 @@ def fix_checkpoint():
def weighted_loss(sd_model, pred, target, mean=True): def weighted_loss(sd_model, pred, target, mean=True):
#Calculate the weight normally, but ignore the mean #Calculate the weight normally, but ignore the mean
loss = sd_model._old_get_loss(pred, target, mean=False) loss = sd_model._old_get_loss(pred, target, mean=False)
#Check if we have weights available #Check if we have weights available
weight = getattr(sd_model, '_custom_loss_weight', None) weight = getattr(sd_model, '_custom_loss_weight', None)
if weight is not None: if weight is not None:
loss *= weight loss *= weight
#Return the loss, as mean if specified #Return the loss, as mean if specified
return loss.mean() if mean else loss return loss.mean() if mean else loss
@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
try: try:
#Temporarily append weights to a place accessible during loss calc #Temporarily append weights to a place accessible during loss calc
sd_model._custom_loss_weight = w sd_model._custom_loss_weight = w
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
if not hasattr(sd_model, '_old_get_loss'): if not hasattr(sd_model, '_old_get_loss'):
@ -120,7 +120,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
del sd_model._custom_loss_weight del sd_model._custom_loss_weight
except AttributeError: except AttributeError:
pass pass
#If we have an old loss function, reset the loss function to the original one #If we have an old loss function, reset the loss function to the original one
if hasattr(sd_model, '_old_get_loss'): if hasattr(sd_model, '_old_get_loss'):
sd_model.get_loss = sd_model._old_get_loss sd_model.get_loss = sd_model._old_get_loss
@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped

View File

@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
end = i + 2 end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale s1 *= self.scale
s2 = s1.softmax(dim=-1) s2 = s1.softmax(dim=-1)
del s1 del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2 del s2
del q, k, v del q, k, v
@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn): with devices.without_autocast(disable=not shared.opts.upcast_attn):
k_in = k_in * self.scale k_in = k_in * self.scale
del context, x del context, x
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
mem_free_total = get_available_vram() mem_free_total = get_available_vram()
gb = 1024 ** 3 gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5 modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier mem_required = tensor_size * modifier
steps = 1 steps = 1
if mem_required > mem_free_total: if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64: if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype) s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1 del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2 del s2
del q, k, v del q, k, v
r1 = r1.to(dtype) r1 = r1.to(dtype)
@ -228,7 +228,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn): with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale k = k * self.scale
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v) r = einsum_op(q, k, v)
r = r.to(dtype) r = r.to(dtype)
@ -369,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
del q_in, k_in, v_in del q_in, k_in, v_in
dtype = q.dtype dtype = q.dtype
@ -451,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
h3 += x h3 += x
return h3 return h3
def xformers_attnblock_forward(self, x): def xformers_attnblock_forward(self, x):
try: try:
h_ = x h_ = x

View File

@ -165,7 +165,7 @@ def model_hash(filename):
def select_checkpoint(): def select_checkpoint():
model_checkpoint = shared.opts.sd_model_checkpoint model_checkpoint = shared.opts.sd_model_checkpoint
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
if checkpoint_info is not None: if checkpoint_info is not None:
return checkpoint_info return checkpoint_info
@ -372,7 +372,7 @@ def enable_midas_autodownload():
if not os.path.exists(path): if not os.path.exists(path):
if not os.path.exists(midas_path): if not os.path.exists(midas_path):
mkdir(midas_path) mkdir(midas_path)
print(f"Downloading midas model weights for {model_type} to {path}") print(f"Downloading midas model weights for {model_type} to {path}")
request.urlretrieve(midas_urls[model_type], path) request.urlretrieve(midas_urls[model_type], path)
print(f"{model_type} downloaded") print(f"{model_type} downloaded")

View File

@ -93,10 +93,10 @@ class CFGDenoiser(torch.nn.Module):
if shared.sd_model.model.conditioning_key == "crossattn-adm": if shared.sd_model.model.conditioning_key == "crossattn-adm":
image_uncond = torch.zeros_like(image_cond) image_uncond = torch.zeros_like(image_cond)
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm} make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
else: else:
image_uncond = image_cond image_uncond = image_cond
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]} make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
if not is_edit_model: if not is_edit_model:
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@ -316,7 +316,7 @@ class KDiffusionSampler:
sigma_sched = sigmas[steps - t_enc - 1:] sigma_sched = sigmas[steps - t_enc - 1:]
xi = x + noise * sigma_sched[0] xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p) extra_params_kwargs = self.initialize(p)
parameters = inspect.signature(self.func).parameters parameters = inspect.signature(self.func).parameters
@ -339,9 +339,9 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
self.last_latent = x self.last_latent = x
extra_args={ extra_args={
'cond': conditioning, 'cond': conditioning,
'image_cond': image_conditioning, 'image_cond': image_conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale, 'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
} }
@ -374,9 +374,9 @@ class KDiffusionSampler:
self.last_latent = x self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning, 'cond': conditioning,
'image_cond': image_conditioning, 'image_cond': image_conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale, 'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) }, disable=False, callback=self.callback_state, **extra_params_kwargs))

View File

@ -179,7 +179,7 @@ def efficient_dot_product_attention(
chunk_idx, chunk_idx,
min(query_chunk_size, q_tokens) min(query_chunk_size, q_tokens)
) )
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial( compute_query_chunk_attn: ComputeQueryChunkAttn = partial(

View File

@ -10,63 +10,64 @@ RED = "#F00"
def crop_image(im, settings): def crop_image(im, settings):
""" Intelligently crop an image to the subject matter """ """ Intelligently crop an image to the subject matter """
scale_by = 1 scale_by = 1
if is_landscape(im.width, im.height): if is_landscape(im.width, im.height):
scale_by = settings.crop_height / im.height scale_by = settings.crop_height / im.height
elif is_portrait(im.width, im.height): elif is_portrait(im.width, im.height):
scale_by = settings.crop_width / im.width scale_by = settings.crop_width / im.width
elif is_square(im.width, im.height): elif is_square(im.width, im.height):
if is_square(settings.crop_width, settings.crop_height): if is_square(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width scale_by = settings.crop_width / im.width
elif is_landscape(settings.crop_width, settings.crop_height): elif is_landscape(settings.crop_width, settings.crop_height):
scale_by = settings.crop_width / im.width scale_by = settings.crop_width / im.width
elif is_portrait(settings.crop_width, settings.crop_height): elif is_portrait(settings.crop_width, settings.crop_height):
scale_by = settings.crop_height / im.height scale_by = settings.crop_height / im.height
im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
focus = focal_point(im_debug, settings) im = im.resize((int(im.width * scale_by), int(im.height * scale_by)))
im_debug = im.copy()
# take the focal point and turn it into crop coordinates that try to center over the focal focus = focal_point(im_debug, settings)
# point but then get adjusted back into the frame
y_half = int(settings.crop_height / 2)
x_half = int(settings.crop_width / 2)
x1 = focus.x - x_half # take the focal point and turn it into crop coordinates that try to center over the focal
if x1 < 0: # point but then get adjusted back into the frame
x1 = 0 y_half = int(settings.crop_height / 2)
elif x1 + settings.crop_width > im.width: x_half = int(settings.crop_width / 2)
x1 = im.width - settings.crop_width
y1 = focus.y - y_half x1 = focus.x - x_half
if y1 < 0: if x1 < 0:
y1 = 0 x1 = 0
elif y1 + settings.crop_height > im.height: elif x1 + settings.crop_width > im.width:
y1 = im.height - settings.crop_height x1 = im.width - settings.crop_width
x2 = x1 + settings.crop_width y1 = focus.y - y_half
y2 = y1 + settings.crop_height if y1 < 0:
y1 = 0
elif y1 + settings.crop_height > im.height:
y1 = im.height - settings.crop_height
crop = [x1, y1, x2, y2] x2 = x1 + settings.crop_width
y2 = y1 + settings.crop_height
results = [] crop = [x1, y1, x2, y2]
results.append(im.crop(tuple(crop))) results = []
if settings.annotate_image: results.append(im.crop(tuple(crop)))
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
return results if settings.annotate_image:
d = ImageDraw.Draw(im_debug)
rect = list(crop)
rect[2] -= 1
rect[3] -= 1
d.rectangle(rect, outline=GREEN)
results.append(im_debug)
if settings.destop_view_image:
im_debug.show()
return results
def focal_point(im, settings): def focal_point(im, settings):
corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else [] corner_points = image_corner_points(im, settings) if settings.corner_points_weight > 0 else []
@ -86,7 +87,7 @@ def focal_point(im, settings):
corner_centroid = None corner_centroid = None
if len(corner_points) > 0: if len(corner_points) > 0:
corner_centroid = centroid(corner_points) corner_centroid = centroid(corner_points)
corner_centroid.weight = settings.corner_points_weight / weight_pref_total corner_centroid.weight = settings.corner_points_weight / weight_pref_total
pois.append(corner_centroid) pois.append(corner_centroid)
entropy_centroid = None entropy_centroid = None
@ -98,7 +99,7 @@ def focal_point(im, settings):
face_centroid = None face_centroid = None
if len(face_points) > 0: if len(face_points) > 0:
face_centroid = centroid(face_points) face_centroid = centroid(face_points)
face_centroid.weight = settings.face_points_weight / weight_pref_total face_centroid.weight = settings.face_points_weight / weight_pref_total
pois.append(face_centroid) pois.append(face_centroid)
average_point = poi_average(pois, settings) average_point = poi_average(pois, settings)
@ -132,7 +133,7 @@ def focal_point(im, settings):
d.rectangle(f.bounding(4), outline=color) d.rectangle(f.bounding(4), outline=color)
d.ellipse(average_point.bounding(max_size), outline=GREEN) d.ellipse(average_point.bounding(max_size), outline=GREEN)
return average_point return average_point
@ -260,10 +261,11 @@ def image_entropy(im):
hist = hist[hist > 0] hist = hist[hist > 0]
return -np.log2(hist / hist.sum()).sum() return -np.log2(hist / hist.sum()).sum()
def centroid(pois): def centroid(pois):
x = [poi.x for poi in pois] x = [poi.x for poi in pois]
y = [poi.y for poi in pois] y = [poi.y for poi in pois]
return PointOfInterest(sum(x)/len(pois), sum(y)/len(pois)) return PointOfInterest(sum(x) / len(pois), sum(y) / len(pois))
def poi_average(pois, settings): def poi_average(pois, settings):
@ -281,59 +283,59 @@ def poi_average(pois, settings):
def is_landscape(w, h): def is_landscape(w, h):
return w > h return w > h
def is_portrait(w, h): def is_portrait(w, h):
return h > w return h > w
def is_square(w, h): def is_square(w, h):
return w == h return w == h
def download_and_cache_models(dirname): def download_and_cache_models(dirname):
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
model_file_name = 'face_detection_yunet.onnx' model_file_name = 'face_detection_yunet.onnx'
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
cache_file = os.path.join(dirname, model_file_name) cache_file = os.path.join(dirname, model_file_name)
if not os.path.exists(cache_file): if not os.path.exists(cache_file):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'") print(f"downloading face detection model from '{download_url}' to '{cache_file}'")
response = requests.get(download_url) response = requests.get(download_url)
with open(cache_file, "wb") as f: with open(cache_file, "wb") as f:
f.write(response.content) f.write(response.content)
if os.path.exists(cache_file): if os.path.exists(cache_file):
return cache_file return cache_file
return None return None
class PointOfInterest: class PointOfInterest:
def __init__(self, x, y, weight=1.0, size=10): def __init__(self, x, y, weight=1.0, size=10):
self.x = x self.x = x
self.y = y self.y = y
self.weight = weight self.weight = weight
self.size = size self.size = size
def bounding(self, size): def bounding(self, size):
return [ return [
self.x - size//2, self.x - size // 2,
self.y - size//2, self.y - size // 2,
self.x + size//2, self.x + size // 2,
self.y + size//2 self.y + size // 2
] ]
class Settings: class Settings:
def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None): def __init__(self, crop_width=512, crop_height=512, corner_points_weight=0.5, entropy_points_weight=0.5, face_points_weight=0.5, annotate_image=False, dnn_model_path=None):
self.crop_width = crop_width self.crop_width = crop_width
self.crop_height = crop_height self.crop_height = crop_height
self.corner_points_weight = corner_points_weight self.corner_points_weight = corner_points_weight
self.entropy_points_weight = entropy_points_weight self.entropy_points_weight = entropy_points_weight
self.face_points_weight = face_points_weight self.face_points_weight = face_points_weight
self.annotate_image = annotate_image self.annotate_image = annotate_image
self.destop_view_image = False self.destop_view_image = False
self.dnn_model_path = dnn_model_path self.dnn_model_path = dnn_model_path

View File

@ -118,7 +118,7 @@ class PersonalizedBase(Dataset):
weight = torch.ones(latent_sample.shape) weight = torch.ones(latent_sample.shape)
else: else:
weight = None weight = None
if latent_sampling_method == "random": if latent_sampling_method == "random":
entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight) entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, weight=weight)
else: else:
@ -243,4 +243,4 @@ class BatchLoaderRandom(BatchLoader):
return self return self
def collate_wrapper_random(batch): def collate_wrapper_random(batch):
return BatchLoaderRandom(batch) return BatchLoaderRandom(batch)

View File

@ -125,7 +125,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
default=None default=None
) )
return wh and center_crop(image, *wh) return wh and center_crop(image, *wh)
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
width = process_width width = process_width

View File

@ -323,16 +323,16 @@ def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epo
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
def tensorboard_add_scaler(tensorboard_writer, tag, value, step): def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
tensorboard_writer.add_scalar(tag=tag, tensorboard_writer.add_scalar(tag=tag,
scalar_value=value, global_step=step) scalar_value=value, global_step=step)
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
# Convert a pil image to a torch tensor # Convert a pil image to a torch tensor
img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0],
len(pil_image.getbands())) len(pil_image.getbands()))
img_tensor = img_tensor.permute((2, 0, 1)) img_tensor = img_tensor.permute((2, 0, 1))
tensorboard_writer.add_image(tag, img_tensor, global_step=step) tensorboard_writer.add_image(tag, img_tensor, global_step=step)
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
@ -402,7 +402,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if initial_step >= steps: if initial_step >= steps:
shared.state.textinfo = "Model has already been trained beyond specified max steps" shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step) scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
@ -412,7 +412,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# dataset loading may take a while, so input validations and early returns should be done before this # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed old_parallel_processing_allowed = shared.parallel_processing_allowed
if shared.opts.training_enable_tensorboard: if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory) tensorboard_writer = tensorboard_setup(log_directory)
@ -439,7 +439,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None): if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
if optimizer_state_dict is not None: if optimizer_state_dict is not None:
optimizer.load_state_dict(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict)
print("Loaded existing optimizer from checkpoint") print("Loaded existing optimizer from checkpoint")
@ -485,7 +485,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
if clip_grad: if clip_grad:
clip_grad_sched.step(embedding.step) clip_grad_sched.step(embedding.step)
with devices.autocast(): with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight: if use_weight:
@ -513,7 +513,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
# go back until we reach gradient accumulation steps # go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0: if (j + 1) % gradient_step != 0:
continue continue
if clip_grad: if clip_grad:
clip_grad(embedding.vec, clip_grad_sched.learn_rate) clip_grad(embedding.vec, clip_grad_sched.learn_rate)

View File

@ -1171,7 +1171,7 @@ def create_ui():
process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight")
process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight")
process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug")
with gr.Column(visible=False) as process_multicrop_col: with gr.Column(visible=False) as process_multicrop_col:
gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') gr.Markdown('Each image is center-cropped with an automatically chosen width and height.')
with gr.Row(): with gr.Row():
@ -1183,7 +1183,7 @@ def create_ui():
with gr.Row(): with gr.Row():
process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective")
process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold")
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
gr.HTML(value="") gr.HTML(value="")
@ -1226,7 +1226,7 @@ def create_ui():
with FormRow(): with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
with FormRow(): with FormRow():
clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)
@ -1565,7 +1565,7 @@ def create_ui():
gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.HTML(shared.html("licenses.html"), elem_id="licenses")
gr.Button(value="Show all pages", elem_id="settings_show_all_pages") gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
def unload_sd_weights(): def unload_sd_weights():
modules.sd_models.unload_model_weights() modules.sd_models.unload_model_weights()
@ -1841,15 +1841,15 @@ def versions_html():
return f""" return f"""
version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a> version: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{tag}</a>
    
python: <span title="{sys.version}">{python_version}</span> python: <span title="{sys.version}">{python_version}</span>
    
torch: {getattr(torch, '__long_version__',torch.__version__)} torch: {getattr(torch, '__long_version__',torch.__version__)}
    
xformers: {xformers_version} xformers: {xformers_version}
    
gradio: {gr.__version__} gradio: {gr.__version__}
    
checkpoint: <a id="sd_checkpoint_hash">N/A</a> checkpoint: <a id="sd_checkpoint_hash">N/A</a>
""" """

View File

@ -467,7 +467,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
<td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td> <td>{html.escape(description)}<p class="info"><span class="date_added">Added: {html.escape(added)}</span></p></td>
<td>{install_code}</td> <td>{install_code}</td>
</tr> </tr>
""" """
for tag in [x for x in extension_tags if x not in tags]: for tag in [x for x in extension_tags if x not in tags]:
@ -535,9 +535,9 @@ def create_ui():
hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index")
with gr.Row(): with gr.Row():
search_extensions_text = gr.Text(label="Search").style(container=False) search_extensions_text = gr.Text(label="Search").style(container=False)
install_result = gr.HTML() install_result = gr.HTML()
available_extensions_table = gr.HTML() available_extensions_table = gr.HTML()

View File

@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
config_class = BertSeriesConfig config_class = BertSeriesConfig
def __init__(self, config=None, **kargs): def __init__(self, config=None, **kargs):
# modify initialization for autoloading # modify initialization for autoloading
if config is None: if config is None:
config = XLMRobertaConfig() config = XLMRobertaConfig()
config.attention_probs_dropout_prob= 0.1 config.attention_probs_dropout_prob= 0.1
@ -74,7 +74,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
text["attention_mask"] = torch.tensor( text["attention_mask"] = torch.tensor(
text['attention_mask']).to(device) text['attention_mask']).to(device)
features = self(**text) features = self(**text)
return features['projection_state'] return features['projection_state']
def forward( def forward(
self, self,
@ -134,4 +134,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel):
class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation):
base_model_prefix = 'roberta' base_model_prefix = 'roberta'
config_class= RobertaSeriesConfig config_class= RobertaSeriesConfig

View File

@ -6,6 +6,7 @@ extend-select = [
"B", "B",
"C", "C",
"I", "I",
"W",
] ]
exclude = [ exclude = [
@ -20,7 +21,7 @@ ignore = [
"I001", # Import block is un-sorted or un-formatted "I001", # Import block is un-sorted or un-formatted
"C901", # Function is too complex "C901", # Function is too complex
"C408", # Rewrite as a literal "C408", # Rewrite as a literal
"W605", # invalid escape sequence, messes with some docstrings
] ]
[tool.ruff.per-file-ignores] [tool.ruff.per-file-ignores]
@ -28,4 +29,4 @@ ignore = [
[tool.ruff.flake8-bugbear] [tool.ruff.flake8-bugbear]
# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"]

View File

@ -149,9 +149,9 @@ class Script(scripts.Script):
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment"))
return [ return [
info, info,
override_sampler, override_sampler,
override_prompt, original_prompt, original_negative_prompt, override_prompt, original_prompt, original_negative_prompt,
override_steps, st, override_steps, st,
override_strength, override_strength,
cfg, randomness, sigma_adjustment, cfg, randomness, sigma_adjustment,
@ -191,17 +191,17 @@ class Script(scripts.Script):
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment) self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p) rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model) sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
sigmas = sampler.model_wrap.get_sigmas(p.steps) sigmas = sampler.model_wrap.get_sigmas(p.steps)
noise_dt = combined_noise - (p.init_latent / sigmas[0]) noise_dt = combined_noise - (p.init_latent / sigmas[0])
p.seed = p.seed + 1 p.seed = p.seed + 1
return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning) return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
p.sample = sample_extra p.sample = sample_extra

View File

@ -14,7 +14,7 @@ class Script(scripts.Script):
def show(self, is_img2img): def show(self, is_img2img):
return is_img2img return is_img2img
def ui(self, is_img2img): def ui(self, is_img2img):
loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops"))
final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength")) final_denoising_strength = gr.Slider(minimum=0, maximum=1, step=0.01, label='Final denoising strength', value=0.5, elem_id=self.elem_id("final_denoising_strength"))
denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear") denoising_curve = gr.Dropdown(label="Denoising strength curve", choices=["Aggressive", "Linear", "Lazy"], value="Linear")
@ -104,7 +104,7 @@ class Script(scripts.Script):
p.seed = processed.seed + 1 p.seed = processed.seed + 1
p.denoising_strength = calculate_denoising_strength(i + 1) p.denoising_strength = calculate_denoising_strength(i + 1)
if state.skipped: if state.skipped:
break break
@ -121,7 +121,7 @@ class Script(scripts.Script):
all_images.append(last_image) all_images.append(last_image)
p.inpainting_fill = original_inpainting_fill p.inpainting_fill = original_inpainting_fill
if state.interrupted: if state.interrupted:
break break
@ -132,7 +132,7 @@ class Script(scripts.Script):
if opts.return_grid: if opts.return_grid:
grids.append(grid) grids.append(grid)
all_images = grids + all_images all_images = grids + all_images
processed = Processed(p, all_images, initial_seed, initial_info) processed = Processed(p, all_images, initial_seed, initial_info)

View File

@ -19,7 +19,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
if not is_img2img: if not is_img2img:
return None return None
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels"))
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur"))
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill"))

View File

@ -96,7 +96,7 @@ class Script(scripts.Script):
p.prompt_for_display = positive_prompt p.prompt_for_display = positive_prompt
processed = process_images(p) processed = process_images(p)
grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2)) grid = images.image_grid(processed.images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size) grid = images.draw_prompt_matrix(grid, processed.images[0].width, processed.images[0].height, prompt_matrix_parts, margin_size)
processed.images.insert(0, grid) processed.images.insert(0, grid)
processed.index_of_first_image = 1 processed.index_of_first_image = 1

View File

@ -109,7 +109,7 @@ class Script(scripts.Script):
def title(self): def title(self):
return "Prompts from file or textbox" return "Prompts from file or textbox"
def ui(self, is_img2img): def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
@ -166,7 +166,7 @@ class Script(scripts.Script):
proc = process_images(copy_p) proc = process_images(copy_p)
images += proc.images images += proc.images
if checkbox_iterate: if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter) p.seed = p.seed + (p.batch_size * p.n_iter)
all_prompts += proc.all_prompts all_prompts += proc.all_prompts

View File

@ -16,7 +16,7 @@ class Script(scripts.Script):
def show(self, is_img2img): def show(self, is_img2img):
return is_img2img return is_img2img
def ui(self, is_img2img): def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>") info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap"))
scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor"))

View File

@ -1,62 +1,64 @@
import unittest import unittest
import requests import requests
class UtilsTests(unittest.TestCase): class UtilsTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.url_options = "http://localhost:7860/sdapi/v1/options" self.url_options = "http://localhost:7860/sdapi/v1/options"
self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags" self.url_cmd_flags = "http://localhost:7860/sdapi/v1/cmd-flags"
self.url_samplers = "http://localhost:7860/sdapi/v1/samplers" self.url_samplers = "http://localhost:7860/sdapi/v1/samplers"
self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers" self.url_upscalers = "http://localhost:7860/sdapi/v1/upscalers"
self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models" self.url_sd_models = "http://localhost:7860/sdapi/v1/sd-models"
self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks" self.url_hypernetworks = "http://localhost:7860/sdapi/v1/hypernetworks"
self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers"
self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models"
self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles"
self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings"
def test_options_get(self): def test_options_get(self):
self.assertEqual(requests.get(self.url_options).status_code, 200) self.assertEqual(requests.get(self.url_options).status_code, 200)
def test_options_write(self): def test_options_write(self):
response = requests.get(self.url_options) response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
pre_value = response.json()["send_seed"] pre_value = response.json()["send_seed"]
self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) self.assertEqual(requests.post(self.url_options, json={"send_seed": not pre_value}).status_code, 200)
response = requests.get(self.url_options) response = requests.get(self.url_options)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.json()["send_seed"], not pre_value) self.assertEqual(response.json()["send_seed"], not pre_value)
requests.post(self.url_options, json={"send_seed": pre_value}) requests.post(self.url_options, json={"send_seed": pre_value})
def test_cmd_flags(self): def test_cmd_flags(self):
self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200)
def test_samplers(self): def test_samplers(self):
self.assertEqual(requests.get(self.url_samplers).status_code, 200) self.assertEqual(requests.get(self.url_samplers).status_code, 200)
def test_upscalers(self): def test_upscalers(self):
self.assertEqual(requests.get(self.url_upscalers).status_code, 200) self.assertEqual(requests.get(self.url_upscalers).status_code, 200)
def test_sd_models(self): def test_sd_models(self):
self.assertEqual(requests.get(self.url_sd_models).status_code, 200) self.assertEqual(requests.get(self.url_sd_models).status_code, 200)
def test_hypernetworks(self): def test_hypernetworks(self):
self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200) self.assertEqual(requests.get(self.url_hypernetworks).status_code, 200)
def test_face_restorers(self): def test_face_restorers(self):
self.assertEqual(requests.get(self.url_face_restorers).status_code, 200) self.assertEqual(requests.get(self.url_face_restorers).status_code, 200)
def test_realesrgan_models(self): def test_realesrgan_models(self):
self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200) self.assertEqual(requests.get(self.url_realesrgan_models).status_code, 200)
def test_prompt_styles(self): def test_prompt_styles(self):
self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200) self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200)
def test_embeddings(self):
self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
def test_embeddings(self):
self.assertEqual(requests.get(self.url_embeddings).status_code, 200)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()