Lora support for SD2
This commit is contained in:
parent
b705c9b72b
commit
650ddc9dd3
@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors
|
|||||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||||
|
|
||||||
re_digits = re.compile(r"\d+")
|
re_digits = re.compile(r"\d+")
|
||||||
re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
|
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
|
||||||
re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
|
re_compiled = {}
|
||||||
re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
|
|
||||||
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
|
suffix_conversion = {
|
||||||
|
"attentions": {},
|
||||||
|
"resnets": {
|
||||||
|
"conv1": "in_layers_2",
|
||||||
|
"conv2": "out_layers_3",
|
||||||
|
"time_emb_proj": "emb_layers_1",
|
||||||
|
"conv_shortcut": "skip_connection",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def convert_diffusers_name_to_compvis(key, is_sd2):
|
def convert_diffusers_name_to_compvis(key, is_sd2):
|
||||||
def match(match_list, regex):
|
def match(match_list, regex_text):
|
||||||
|
regex = re_compiled.get(regex_text)
|
||||||
|
if regex is None:
|
||||||
|
regex = re.compile(regex_text)
|
||||||
|
re_compiled[regex_text] = regex
|
||||||
|
|
||||||
r = re.match(regex, key)
|
r = re.match(regex, key)
|
||||||
if not r:
|
if not r:
|
||||||
return False
|
return False
|
||||||
@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
|
|||||||
|
|
||||||
m = []
|
m = []
|
||||||
|
|
||||||
if match(m, re_unet_down_blocks):
|
if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||||
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
|
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||||
|
return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||||
|
|
||||||
if match(m, re_unet_mid_blocks):
|
if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
|
||||||
return f"diffusion_model_middle_block_1_{m[1]}"
|
suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
|
||||||
|
return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
|
||||||
|
|
||||||
if match(m, re_unet_up_blocks):
|
if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
|
||||||
return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
|
suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
|
||||||
|
return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
|
||||||
|
|
||||||
if match(m, re_text_block):
|
if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
|
||||||
|
return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
|
||||||
|
|
||||||
|
if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
|
||||||
|
return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
|
||||||
|
|
||||||
|
if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
|
||||||
if is_sd2:
|
if is_sd2:
|
||||||
if 'mlp_fc1' in m[1]:
|
if 'mlp_fc1' in m[1]:
|
||||||
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
|
||||||
@ -109,16 +131,22 @@ def load_lora(name, filename):
|
|||||||
|
|
||||||
sd = sd_models.read_state_dict(filename)
|
sd = sd_models.read_state_dict(filename)
|
||||||
|
|
||||||
keys_failed_to_match = []
|
keys_failed_to_match = {}
|
||||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||||
|
|
||||||
for key_diffusers, weight in sd.items():
|
for key_diffusers, weight in sd.items():
|
||||||
fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
|
key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
|
||||||
key, lora_key = fullkey.split(".", 1)
|
key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
|
||||||
|
|
||||||
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
|
||||||
|
|
||||||
if sd_module is None:
|
if sd_module is None:
|
||||||
keys_failed_to_match.append(key_diffusers)
|
m = re_x_proj.match(key)
|
||||||
|
if m:
|
||||||
|
sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
|
||||||
|
|
||||||
|
if sd_module is None:
|
||||||
|
keys_failed_to_match[key_diffusers] = key
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_module = lora.modules.get(key, None)
|
lora_module = lora.modules.get(key, None)
|
||||||
@ -133,7 +161,9 @@ def load_lora(name, filename):
|
|||||||
if type(sd_module) == torch.nn.Linear:
|
if type(sd_module) == torch.nn.Linear:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
|
||||||
module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||||
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d:
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
else:
|
else:
|
||||||
@ -190,54 +220,94 @@ def load_loras(names, multipliers=None):
|
|||||||
loaded_loras.append(lora)
|
loaded_loras.append(lora)
|
||||||
|
|
||||||
|
|
||||||
def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
|
def lora_calc_updown(lora, module, target):
|
||||||
|
with torch.no_grad():
|
||||||
|
up = module.up.weight.to(target.device, dtype=target.dtype)
|
||||||
|
down = module.down.weight.to(target.device, dtype=target.dtype)
|
||||||
|
|
||||||
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
updown = up @ down
|
||||||
|
|
||||||
|
updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||||
|
|
||||||
|
return updown
|
||||||
|
|
||||||
|
|
||||||
|
def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of Loras to the weight of torch layer self.
|
Applies the currently selected set of Loras to the weights of torch layer self.
|
||||||
If weights already have this particular set of loras applied, does nothing.
|
If weights already have this particular set of loras applied, does nothing.
|
||||||
If not, restores orginal weights from backup and alters weights according to loras.
|
If not, restores orginal weights from backup and alters weights according to loras.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
lora_layer_name = getattr(self, 'lora_layer_name', None)
|
||||||
|
if lora_layer_name is None:
|
||||||
|
return
|
||||||
|
|
||||||
current_names = getattr(self, "lora_current_names", ())
|
current_names = getattr(self, "lora_current_names", ())
|
||||||
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
|
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
|
||||||
|
|
||||||
weights_backup = getattr(self, "lora_weights_backup", None)
|
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||||
if weights_backup is None:
|
if weights_backup is None:
|
||||||
weights_backup = self.weight.to(devices.cpu, copy=True)
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
|
||||||
|
else:
|
||||||
|
weights_backup = self.weight.to(devices.cpu, copy=True)
|
||||||
|
|
||||||
self.lora_weights_backup = weights_backup
|
self.lora_weights_backup = weights_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
if weights_backup is not None:
|
if weights_backup is not None:
|
||||||
self.weight.copy_(weights_backup)
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
|
else:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
lora_layer_name = getattr(self, 'lora_layer_name', None)
|
|
||||||
for lora in loaded_loras:
|
for lora in loaded_loras:
|
||||||
module = lora.modules.get(lora_layer_name, None)
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
|
if module is not None and hasattr(self, 'weight'):
|
||||||
|
self.weight += lora_calc_updown(lora, module, self.weight)
|
||||||
|
continue
|
||||||
|
|
||||||
|
module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
|
||||||
|
module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
|
||||||
|
module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
|
||||||
|
module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
|
||||||
|
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
|
||||||
|
updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
|
||||||
|
updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
|
||||||
|
updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
|
||||||
|
updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
|
||||||
|
|
||||||
|
self.in_proj_weight += updown_qkv
|
||||||
|
self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
|
||||||
|
continue
|
||||||
|
|
||||||
if module is None:
|
if module is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
with torch.no_grad():
|
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
||||||
up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
|
|
||||||
down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
|
|
||||||
|
|
||||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
|
||||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
updown = up @ down
|
|
||||||
|
|
||||||
self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
|
||||||
|
|
||||||
setattr(self, "lora_current_names", wanted_names)
|
setattr(self, "lora_current_names", wanted_names)
|
||||||
|
|
||||||
|
|
||||||
|
def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear):
|
||||||
|
setattr(self, "lora_current_names", ())
|
||||||
|
setattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_forward(self, input):
|
def lora_Linear_forward(self, input):
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Linear_forward_before_lora(self, input)
|
return torch.nn.Linear_forward_before_lora(self, input)
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
|
def lora_Linear_load_state_dict(self, *args, **kwargs):
|
||||||
setattr(self, "lora_current_names", ())
|
lora_reset_cached_weight(self)
|
||||||
setattr(self, "lora_weights_backup", None)
|
|
||||||
|
|
||||||
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
|
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
|
||||||
|
|
||||||
@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input):
|
|||||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||||
|
|
||||||
|
|
||||||
def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
|
def lora_Conv2d_load_state_dict(self, *args, **kwargs):
|
||||||
setattr(self, "lora_current_names", ())
|
lora_reset_cached_weight(self)
|
||||||
setattr(self, "lora_weights_backup", None)
|
|
||||||
|
|
||||||
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
|
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def lora_NonDynamicallyQuantizableLinear_forward(self, input):
|
def lora_MultiheadAttention_forward(self, *args, **kwargs):
|
||||||
return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
|
lora_apply_weights(self)
|
||||||
|
|
||||||
|
return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
||||||
|
lora_reset_cached_weight(self)
|
||||||
|
|
||||||
|
return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
|
@ -12,6 +12,8 @@ def unload():
|
|||||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||||
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
|
||||||
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
|
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
|
||||||
|
torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
|
||||||
|
torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
|
||||||
|
|
||||||
|
|
||||||
def before_ui():
|
def before_ui():
|
||||||
@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
|
|||||||
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
|
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
|
||||||
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
|
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
|
||||||
|
torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
|
||||||
|
|
||||||
|
if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
|
||||||
|
torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
|
||||||
|
|
||||||
torch.nn.Linear.forward = lora.lora_Linear_forward
|
torch.nn.Linear.forward = lora.lora_Linear_forward
|
||||||
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
|
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
|
||||||
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
|
||||||
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
|
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
|
||||||
|
torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
|
||||||
|
torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
|
||||||
|
|
||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
|
Loading…
Reference in New Issue
Block a user