train: change filename processing to be more simple and configurable

train: make it possible to make text files with prompts
train: rework scheduler so that there's less repeating code in textual inversion and hypernets
train: move epochs setting to options
This commit is contained in:
AUTOMATIC 2022-10-12 20:49:47 +03:00
parent cc5803603b
commit c3c8eef9fd
7 changed files with 106 additions and 63 deletions

View File

@ -81,6 +81,9 @@ titles = {
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
"Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
} }

View File

@ -14,7 +14,7 @@ import torch
from torch import einsum from torch import einsum
from einops import rearrange, repeat from einops import rearrange, repeat
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps: if ititial_step > steps:
return hypernetwork, filename return hypernetwork, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
(learn_rate, end_step) = next(schedules) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar: for i, entry in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
if hypernetwork.step > end_step: scheduler.apply(optimizer, hypernetwork.step)
try: if scheduler.finished:
(learn_rate, end_step) = next(schedules) break
except Exception:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted: if shared.state.interrupted:
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
cond = cond.to(devices.device) cond = entry.cond.to(devices.device)
x = x.to(devices.device) x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0] loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x del x
del cond del cond
@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad() optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
) )
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0] if len(processed.images)>0 else None
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image if image is not None:
image.save(last_saved_image) shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}" last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step shared.state.job_no = hypernetwork.step
@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/> Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>

View File

@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {

View File

@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared from modules import devices, shared
import re import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None):
self.filename = filename
self.latent = latent
self.filename_text = filename_text
self.cond = None
self.cond_text = None
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception: except Exception:
continue continue
text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path) filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
filename_tokens = re_tag.findall(filename_tokens) if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8) npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32) npimage = (npimage / 127.5 - 1.0).astype(np.float32)
@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu) init_latent = init_latent.to(devices.cpu)
if include_cond: entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
text = self.create_text(filename_tokens)
cond = cond_model([text]).to(devices.cpu)
else:
cond = None
self.dataset.append((init_latent, filename_tokens, cond)) if include_cond:
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
self.dataset.append(entry)
self.length = len(self.dataset) * repeats self.length = len(self.dataset) * repeats
@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self): def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def create_text(self, filename_tokens): def create_text(self, filename_text):
text = random.choice(self.lines) text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token) text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens)) text = text.replace("[filewords]", filename_text)
return text return text
def __len__(self): def __len__(self):
@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle() self.shuffle()
index = self.indexes[i % len(self.indexes)] index = self.indexes[i % len(self.indexes)]
x, filename_tokens, cond = self.dataset[index] entry = self.dataset[index]
text = self.create_text(filename_tokens) if entry.cond is None:
return x, text, cond entry.cond_text = self.create_text(entry.filename_text)
return entry

View File

@ -1,6 +1,12 @@
import tqdm
class LearnSchedule:
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0): def __init__(self, learn_rate, max_steps, cur_step=0):
"""
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
"""
pairs = learn_rate.split(',') pairs = learn_rate.split(',')
self.rates = [] self.rates = []
self.it = 0 self.it = 0
@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1] return self.rates[self.it - 1]
else: else:
raise StopIteration raise StopIteration
class LearnRateScheduler:
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
(self.learn_rate, self.end_step) = next(self.schedules)
self.verbose = verbose
if self.verbose:
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
self.finished = False
def apply(self, optimizer, step_number):
if step_number <= self.end_step:
return
try:
(self.learn_rate, self.end_step) = next(self.schedules)
except Exception:
self.finished = True
return
if self.verbose:
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
for pg in optimizer.param_groups:
pg['lr'] = self.learn_rate

View File

@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed, insert_image_data_embed, extract_image_data_embed,
@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps: if ititial_step > steps:
return embedding, filename return embedding, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
(learn_rate, end_step) = next(schedules) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text, _) in pbar: for i, entry in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
if embedding.step > end_step: scheduler.apply(optimizer, embedding.step)
try: if scheduler.finished:
(learn_rate, end_step) = next(schedules) break
except:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
if shared.state.interrupted: if shared.state.interrupted:
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = cond_model([text]) c = cond_model([entry.cond_text])
x = x.to(devices.device) x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0] loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x del x
@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/> Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/> Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>

View File

@ -1098,7 +1098,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0) steps = gr.Number(label='Max steps', value=100000, precision=0)
num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
@ -1176,7 +1175,6 @@ def create_ui(wrap_gradio_gpu_call):
training_width, training_width,
training_height, training_height,
steps, steps,
num_repeats,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
template_file, template_file,