suggestions and fixes from the PR

This commit is contained in:
AUTOMATIC 2023-05-10 21:21:32 +03:00
parent d25219b7e8
commit 3ec7b705c7
11 changed files with 16 additions and 31 deletions

View File

@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(lora.available_loras)}, refresh=lora.list_available_loras), "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
})) }))

View File

@ -644,17 +644,13 @@ class SwinIR(nn.Module):
""" """
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=None, num_heads=None, embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs): **kwargs):
super(SwinIR, self).__init__() super(SwinIR, self).__init__()
depths = depths or [6, 6, 6, 6]
num_heads = num_heads or [6, 6, 6, 6]
num_in_ch = in_chans num_in_ch = in_chans
num_out_ch = in_chans num_out_ch = in_chans
num_feat = 64 num_feat = 64

View File

@ -74,12 +74,9 @@ class WindowAttention(nn.Module):
""" """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=None): pretrained_window_size=(0, 0)):
super().__init__() super().__init__()
pretrained_window_size = pretrained_window_size or [0, 0]
self.dim = dim self.dim = dim
self.window_size = window_size # Wh, Ww self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size self.pretrained_window_size = pretrained_window_size
@ -701,17 +698,13 @@ class Swin2SR(nn.Module):
""" """
def __init__(self, img_size=64, patch_size=1, in_chans=3, def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=None, num_heads=None, embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
window_size=7, mlp_ratio=4., qkv_bias=True, window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs): **kwargs):
super(Swin2SR, self).__init__() super(Swin2SR, self).__init__()
depths = depths or [6, 6, 6, 6]
num_heads = num_heads or [6, 6, 6, 6]
num_in_ch = in_chans num_in_ch = in_chans
num_out_ch = in_chans num_out_ch = in_chans
num_feat = 64 num_feat = 64

View File

@ -161,13 +161,10 @@ class Fuse_sft_block(nn.Module):
class CodeFormer(VQAutoEncoder): class CodeFormer(VQAutoEncoder):
def __init__(self, dim_embd=512, n_head=8, n_layers=9, def __init__(self, dim_embd=512, n_head=8, n_layers=9,
codebook_size=1024, latent_size=256, codebook_size=1024, latent_size=256,
connect_list=None, connect_list=('32', '64', '128', '256'),
fix_modules=None): fix_modules=('quantize', 'generator')):
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
connect_list = connect_list or ['32', '64', '128', '256']
fix_modules = fix_modules or ['quantize', 'generator']
if fix_modules is not None: if fix_modules is not None:
for module in fix_modules: for module in fix_modules:
for param in getattr(self, module).parameters(): for param in getattr(self, module).parameters():

View File

@ -5,13 +5,13 @@ import modules.hypernetworks.hypernetwork
from modules import devices, sd_hijack, shared from modules import devices, sd_hijack, shared
not_available = ["hardswish", "multiheadattention"] not_available = ["hardswish", "multiheadattention"]
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available] keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {filename}", "" return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
def train_hypernetwork(*args): def train_hypernetwork(*args):

View File

@ -275,8 +275,8 @@ def model_wrapper(
A noise prediction model that accepts the noised data and the continuous time as the inputs. A noise prediction model that accepts the noised data and the continuous time as the inputs.
""" """
model_kwargs = model_kwargs or [] model_kwargs = model_kwargs or {}
classifier_kwargs = classifier_kwargs or [] classifier_kwargs = classifier_kwargs or {}
def get_model_input_time(t_continuous): def get_model_input_time(t_continuous):
""" """

View File

@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
script_args = args[script.args_from:script.args_to] script_args = args[script.args_from:script.args_to]
process_args = {} process_args = {}
for (name, component), value in zip(script.controls.items(), script_args): # noqa B007 for (name, _component), value in zip(script.controls.items(), script_args):
process_args[name] = value process_args[name] = value
script.process(pp, **process_args) script.process(pp, **process_args)

View File

@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
self.hijack.fixes = [x.fixes for x in batch_chunk] self.hijack.fixes = [x.fixes for x in batch_chunk]
for fixes in self.hijack.fixes: for fixes in self.hijack.fixes:
for position, embedding in fixes: # noqa: B007 for _position, embedding in fixes:
used_embeddings[embedding.name] = embedding used_embeddings[embedding.name] = embedding
z = self.process_tokens(tokens, multipliers) z = self.process_tokens(tokens, multipliers)

View File

@ -381,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"), "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"), "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(hypernetworks.keys())}, refresh=reload_hypernetworks), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", hypernetworks]}, refresh=reload_hypernetworks),
})) }))
options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "User interface"), {

View File

@ -166,8 +166,7 @@ class EmbeddingDatabase:
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
param_dict = data['string_to_param'] param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'): param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
param_dict = param_dict._parameters # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it' assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1] emb = next(iter(param_dict.items()))[1]
# diffuser concepts # diffuser concepts

View File

@ -1230,8 +1230,8 @@ def create_ui():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=list(shared.hypernetworks.keys())) train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks.keys())}, "refresh_train_hypernetwork_name") create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
with FormRow(): with FormRow():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")