add embedding load and save from b64 json

This commit is contained in:
DepFA 2022-10-09 21:58:14 +01:00 committed by GitHub
parent fa0c5eb81b
commit 03694e1f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,9 +7,11 @@ import tqdm
import html
import datetime
from PIL import Image, PngImagePlugin
from PIL import Image,PngImagePlugin
from ..images import captionImge
import numpy as np
import base64
from io import BytesIO
import json
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
@ -87,9 +89,9 @@ class EmbeddingDatabase:
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
if 'sd-embedding' in embed_image.text:
embeddingData = base64.b64decode(embed_image.text['sd-embedding'])
data = torch.load(BytesIO(embeddingData), map_location="cpu")
if 'sd-ti-embedding' in embed_image.text:
data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
name = data.get('name',name)
else:
data = torch.load(path, map_location="cpu")
@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if save_image_with_stored_embedding:
info = PngImagePlugin.PngInfo()
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read()))
image.save(last_saved_image, "PNG", pnginfo=info)
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embeddingToB64(data))
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
caption_stepcount = data.get('step',0)
caption_stepcount = caption_stepcount if caption_stepcount else 0
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
caption_stepcount))]
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else:
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step