suggestions and fixes from the PR
This commit is contained in:
parent
d25219b7e8
commit
3ec7b705c7
@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
||||
|
||||
|
||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(lora.available_loras)}, refresh=lora.list_available_loras),
|
||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
|
||||
}))
|
||||
|
||||
|
||||
|
@ -644,17 +644,13 @@ class SwinIR(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||
embed_dim=96, depths=None, num_heads=None,
|
||||
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||
**kwargs):
|
||||
super(SwinIR, self).__init__()
|
||||
|
||||
depths = depths or [6, 6, 6, 6]
|
||||
num_heads = num_heads or [6, 6, 6, 6]
|
||||
|
||||
num_in_ch = in_chans
|
||||
num_out_ch = in_chans
|
||||
num_feat = 64
|
||||
|
@ -74,12 +74,9 @@ class WindowAttention(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||
pretrained_window_size=None):
|
||||
pretrained_window_size=(0, 0)):
|
||||
|
||||
super().__init__()
|
||||
|
||||
pretrained_window_size = pretrained_window_size or [0, 0]
|
||||
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.pretrained_window_size = pretrained_window_size
|
||||
@ -701,17 +698,13 @@ class Swin2SR(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||
embed_dim=96, depths=None, num_heads=None,
|
||||
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||
**kwargs):
|
||||
super(Swin2SR, self).__init__()
|
||||
|
||||
depths = depths or [6, 6, 6, 6]
|
||||
num_heads = num_heads or [6, 6, 6, 6]
|
||||
|
||||
num_in_ch = in_chans
|
||||
num_out_ch = in_chans
|
||||
num_feat = 64
|
||||
|
@ -161,13 +161,10 @@ class Fuse_sft_block(nn.Module):
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=None,
|
||||
fix_modules=None):
|
||||
connect_list=('32', '64', '128', '256'),
|
||||
fix_modules=('quantize', 'generator')):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
connect_list = connect_list or ['32', '64', '128', '256']
|
||||
fix_modules = fix_modules or ['quantize', 'generator']
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
for param in getattr(self, module).parameters():
|
||||
|
@ -5,13 +5,13 @@ import modules.hypernetworks.hypernetwork
|
||||
from modules import devices, sd_hijack, shared
|
||||
|
||||
not_available = ["hardswish", "multiheadattention"]
|
||||
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available]
|
||||
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
||||
|
||||
|
||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||
|
||||
return gr.Dropdown.update(choices=sorted(shared.hypernetworks.keys())), f"Created: {filename}", ""
|
||||
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
||||
|
||||
|
||||
def train_hypernetwork(*args):
|
||||
|
@ -275,8 +275,8 @@ def model_wrapper(
|
||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||
"""
|
||||
|
||||
model_kwargs = model_kwargs or []
|
||||
classifier_kwargs = classifier_kwargs or []
|
||||
model_kwargs = model_kwargs or {}
|
||||
classifier_kwargs = classifier_kwargs or {}
|
||||
|
||||
def get_model_input_time(t_continuous):
|
||||
"""
|
||||
|
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
|
||||
script_args = args[script.args_from:script.args_to]
|
||||
|
||||
process_args = {}
|
||||
for (name, component), value in zip(script.controls.items(), script_args): # noqa B007
|
||||
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||
process_args[name] = value
|
||||
|
||||
script.process(pp, **process_args)
|
||||
|
@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||
|
||||
for fixes in self.hijack.fixes:
|
||||
for position, embedding in fixes: # noqa: B007
|
||||
for _position, embedding in fixes:
|
||||
used_embeddings[embedding.name] = embedding
|
||||
|
||||
z = self.process_tokens(tokens, multipliers)
|
||||
|
@ -381,7 +381,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + list(hypernetworks.keys())}, refresh=reload_hypernetworks),
|
||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", hypernetworks]}, refresh=reload_hypernetworks),
|
||||
}))
|
||||
|
||||
options_templates.update(options_section(('ui', "User interface"), {
|
||||
|
@ -166,8 +166,7 @@ class EmbeddingDatabase:
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = param_dict._parameters # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
|
@ -1230,8 +1230,8 @@ def create_ui():
|
||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
|
||||
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=list(shared.hypernetworks.keys()))
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks.keys())}, "refresh_train_hypernetwork_name")
|
||||
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=sorted(shared.hypernetworks))
|
||||
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted(shared.hypernetworks)}, "refresh_train_hypernetwork_name")
|
||||
|
||||
with FormRow():
|
||||
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
|
||||
|
Loading…
Reference in New Issue
Block a user