Namespace metadata fields

This commit is contained in:
space-nuko 2023-04-02 21:41:23 -05:00
parent 7c016dd642
commit fbaf6e4fd8

View File

@ -242,7 +242,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving" shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
metadata = {"format": "pt", "models": {}, "merge_recipe": None} metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
if save_metadata: if save_metadata:
merge_recipe = { merge_recipe = {
@ -260,17 +260,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
"is_inpainting": result_is_inpainting_model, "is_inpainting": result_is_inpainting_model,
"is_instruct_pix2pix": result_is_instruct_pix2pix_model "is_instruct_pix2pix": result_is_instruct_pix2pix_model
} }
metadata["merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
def add_model_metadata(checkpoint_info): def add_model_metadata(checkpoint_info):
checkpoint_info.calculate_shorthash() checkpoint_info.calculate_shorthash()
metadata["models"][checkpoint_info.sha256] = { metadata["sd_merge_models"][checkpoint_info.sha256] = {
"name": checkpoint_info.name, "name": checkpoint_info.name,
"legacy_hash": checkpoint_info.hash, "legacy_hash": checkpoint_info.hash,
"merge_recipe": checkpoint_info.metadata.get("merge_recipe", None) "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
} }
metadata["models"].update(checkpoint_info.metadata.get("models", {})) metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
add_model_metadata(primary_model_info) add_model_metadata(primary_model_info)
if secondary_model_info: if secondary_model_info:
@ -278,7 +278,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if tertiary_model_info: if tertiary_model_info:
add_model_metadata(tertiary_model_info) add_model_metadata(tertiary_model_info)
metadata["models"] = json.dumps(metadata["models"]) metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
_, extension = os.path.splitext(output_modelname) _, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":