diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b8695fc1..7d519cd9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -22,45 +22,86 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - def __init__(self, dim, state_dict=None): + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() + if layer_structure is not None: + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" + else: + layer_structure = parse_layer_structure(dim, state_dict) - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) + linears = [] + for i in range(len(layer_structure) - 1): + linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) + + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - self.load_state_dict(state_dict, strict=True) + try: + self.load_state_dict(state_dict) + except RuntimeError: + self.try_load_previous(state_dict) else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() + for layer in self.linear: + layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.bias.data.zero_() self.to(devices.device) + def try_load_previous(self, state_dict): + states = self.state_dict() + states['linear.0.bias'].copy_(state_dict['linear1.bias']) + states['linear.0.weight'].copy_(state_dict['linear1.weight']) + states['linear.1.bias'].copy_(state_dict['linear2.bias']) + states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def forward(self, x): - return x + (self.linear2(self.linear1(x))) * self.multiplier + return x + self.linear(x) * self.multiplier + + def trainables(self): + layer_structure = [] + for layer in self.linear: + layer_structure += [layer.weight, layer.bias] + return layer_structure def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +def parse_layer_structure(dim, state_dict): + i = 0 + layer_structure = [1] + + while (key := "linear.{}.weight".format(i)) in state_dict: + weight = state_dict[key] + layer_structure.append(len(weight) // dim) + i += 1 + + return layer_structure + + class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): self.filename = None self.name = name self.layers = {} self.step = 0 self.sd_checkpoint = None self.sd_checkpoint_name = None + self.layer_structure = layer_structure + self.add_layer_norm = add_layer_norm for size in enable_sizes or []: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + self.layers[size] = ( + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + ) def weights(self): res = [] @@ -68,7 +109,7 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + res += layer.trainables() return res @@ -80,6 +121,8 @@ class Hypernetwork: state_dict['step'] = self.step state_dict['name'] = self.name + state_dict['layer_structure'] = self.layer_structure + state_dict['is_layer_norm'] = self.add_layer_norm state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -94,10 +137,15 @@ class Hypernetwork: for size, sd in state_dict.items(): if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + self.layers[size] = ( + HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) + self.layer_structure = state_dict.get('layer_structure', None) + self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) @@ -226,7 +274,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) - if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) @@ -261,7 +308,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"): c = stack_conds([entry.cond for entry in entries]).to(devices.device) -# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) x = torch.stack([entry.latent for entry in entries]).to(devices.device) loss = shared.sd_model(x, c)[0] del x diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index dfa599af..7e8ea95e 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes): +def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes]) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( + name=name, + enable_sizes=[int(x) for x in enable_sizes], + layer_structure=layer_structure, + add_layer_norm=add_layer_norm, + ) hypernet.save(fn) shared.reload_hypernetworks() diff --git a/modules/shared.py b/modules/shared.py index f7d66870..faede821 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -137,7 +137,7 @@ class State: self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 - + def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? diff --git a/modules/ui.py b/modules/ui.py index 1ff7eb4f..e5940063 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -477,14 +477,14 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" ) with gr.Row(): with gr.Column(scale=80): with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" ) @@ -1217,6 +1217,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]) + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") with gr.Row(): with gr.Column(scale=3): @@ -1299,6 +1301,8 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, + new_hypernetwork_layer_structure, + new_hypernetwork_add_layer_norm, ], outputs=[ train_hypernetwork_name, diff --git a/webui.py b/webui.py index 71724c3b..177bef74 100644 --- a/webui.py +++ b/webui.py @@ -140,7 +140,7 @@ def webui(launch_api=False): create_api(app) wait_on_server(demo) - + sd_samplers.set_samplers() print('Reloading Custom Scripts') @@ -158,4 +158,4 @@ if __name__ == "__main__": if cmd_opts.nowebui: api_only() else: - webui(cmd_opts.api) \ No newline at end of file + webui(cmd_opts.api)