From ab27c111d06ec920791c73eea25ad9a61671852e Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 18:09:17 +0700 Subject: [PATCH 1/6] Add input validations before loading dataset for training --- modules/hypernetworks/hypernetwork.py | 38 ++++++++------ .../textual_inversion/textual_inversion.py | 50 ++++++++++++++----- 2 files changed, 59 insertions(+), 29 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 2e84583b..38f35c58 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -332,7 +332,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images - assert hypernetwork_name, 'hypernetwork not selected' + save_hypernetwork_every = save_hypernetwork_every or 0 + create_image_every = create_image_every or 0 + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() @@ -358,39 +360,43 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log else: images_dir = None + hypernetwork = shared.loaded_hypernetwork + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return hypernetwork, filename + + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - size = len(ds.indexes) loss_dict = defaultdict(lambda : deque(maxlen = 1024)) losses = torch.zeros((size,)) previous_mean_losses = [0] previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) - - last_saved_file = "" - last_saved_image = "" - forced_filename = "" - - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) steps_without_grad = 0 + last_saved_file = "" + last_saved_image = "" + forced_filename = "" + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 17dfb223..44f06443 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): + assert model_name, f"{name} not selected" + assert learn_rate, "Learning rate is empty or 0" + assert isinstance(batch_size, int), "Batch size must be integer" + assert batch_size > 0, "Batch size must be positive" + assert data_root, "Dataset directory is empty" + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" + assert template_file, "Prompt template file is empty" + assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert steps, "Max steps is empty or 0" + assert isinstance(steps, int), "Max steps must be integer" + assert steps > 0 , "Max steps must be positive" + assert isinstance(save_model_every, int), "Save {name} must be integer" + assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert isinstance(create_image_every, int), "Create image must be integer" + assert create_image_every >= 0 , "Create image must be positive or 0" + if save_model_every or create_image_every: + assert log_directory, "Log directory is empty" def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - assert embedding_name, 'embedding not selected' + save_embedding_every = save_embedding_every or 0 + create_image_every = create_image_every or 0 + validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") shared.state.textinfo = "Initializing textual inversion training..." shared.state.job_count = steps @@ -232,17 +253,27 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc os.makedirs(images_embeds_dir, exist_ok=True) else: images_embeds_dir = None - - cond_model = shared.sd_model.cond_stage_model - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + cond_model = shared.sd_model.cond_stage_model hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] + + ititial_step = embedding.step or 0 + if ititial_step > steps: + shared.state.textinfo = f"Model has already been trained beyond specified max steps" + return embedding, filename + + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + + # dataset loading may take a while, so input validations and early returns should be done before this + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + embedding.vec.requires_grad = True + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) losses = torch.zeros((32,)) @@ -251,13 +282,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc forced_filename = "" embedding_yet_to_be_embedded = False - ititial_step = embedding.step or 0 - if ititial_step > steps: - return embedding, filename - - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, entries in pbar: embedding.step = i + ititial_step From 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 19:43:21 +0700 Subject: [PATCH 2/6] Add cleanup after training --- modules/hypernetworks/hypernetwork.py | 187 +++++++++--------- .../textual_inversion/textual_inversion.py | 163 +++++++-------- 2 files changed, 182 insertions(+), 168 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..170d5ea4 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -398,110 +398,112 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log forced_filename = "" pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) + try: + for i, entries in pbar: + hypernetwork.step = i + ititial_step + if len(loss_dict) > 0: + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) - optimizer.zero_grad() - weights[0].grad = None - loss.backward() + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break - if weights[0].grad is None: - steps_without_grad += 1 + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = stack_conds([entry.cond for entry in entries]).to(devices.device) + # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + del c + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + for entry in entries: + loss_dict[entry.filename].append(loss.item()) + + optimizer.zero_grad() + weights[0].grad = None + loss.backward() + + if weights[0].grad is None: + steps_without_grad += 1 + else: + steps_without_grad = 0 + assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + + optimizer.step() + + steps_done = hypernetwork.step + 1 + + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + raise RuntimeError("Loss diverged.") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) - optimizer.step() + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') + hypernetwork.save(last_saved_file) - steps_done = hypernetwork.step + 1 + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{previous_mean_loss:.7f}", + "learn_rate": scheduler.learn_rate + }) - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") - - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) - else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) + preview_text = p.prompt - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) - preview_text = p.prompt + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None + shared.state.job_no = hypernetwork.step - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) - - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -510,7 +512,14 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - + finally: + if weights: + for weight in weights: + weight.requires_grad = False + if unload: + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..fd7f0897 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -283,111 +283,113 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, entries in pbar: - embedding.step = i + ititial_step - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break + try: + for i, entries in pbar: + embedding.step = i + ititial_step - if shared.state.interrupted: - break + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x + if shared.state.interrupted: + break - losses[embedding.step % losses.shape[0]] = loss.item() + with torch.autocast("cuda"): + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x - optimizer.zero_grad() - loss.backward() - optimizer.step() + losses[embedding.step % losses.shape[0]] = loss.item() - steps_done = embedding.step + 1 + optimizer.zero_grad() + loss.backward() + optimizer.step() - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) + steps_done = embedding.step + 1 - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step % len(ds) - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) - embedding_yet_to_be_embedded = True + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') + embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) - preview_text = p.prompt + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height - processed = processing.process_images(p) - image = processed.images[0] + preview_text = p.prompt - shared.state.current_image = image + processed = processing.process_images(p) + image = processed.images[0] - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + shared.state.current_image = image - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - title = "<{}>".format(data.get('name', '???')) + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' + title = "<{}>".format(data.get('name', '???')) - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - shared.state.job_no = embedding.step + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - shared.state.textinfo = f""" + shared.state.job_no = embedding.step + + shared.state.textinfo = f"""

Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" + finally: + if embedding and embedding.vec is not None: + embedding.vec.requires_grad = False checkpoint = sd_models.select_checkpoint() From a27d19de2eff633b6a39f9f4a5c0f2d6abb81bb5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 29 Oct 2022 19:44:05 +0700 Subject: [PATCH 3/6] Additional assert on dataset --- modules/textual_inversion/dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 8bb00d27..ad726577 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -42,6 +42,8 @@ class PersonalizedBase(Dataset): self.lines = lines assert data_root, 'dataset directory not specified' + assert os.path.isdir(data_root), "Dataset directory doesn't exist" + assert os.listdir(data_root), "Dataset directory is empty" cond_model = shared.sd_model.cond_stage_model From ab05a74ead9fabb45dd099990e34061c7eb02ca3 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:32:02 +0700 Subject: [PATCH 4/6] Revert "Add cleanup after training" This reverts commit 3ce2bfdf95bd5f26d0f6e250e67338ada91980d1. --- modules/hypernetworks/hypernetwork.py | 191 +++++++++--------- .../textual_inversion/textual_inversion.py | 163 ++++++++------- 2 files changed, 170 insertions(+), 184 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 170d5ea4..38f35c58 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -398,112 +398,110 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log forced_filename = "" pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - - try: - for i, entries in pbar: - hypernetwork.step = i + ititial_step - if len(loss_dict) > 0: - previous_mean_losses = [i[-1] for i in loss_dict.values()] - previous_mean_loss = mean(previous_mean_losses) - - scheduler.apply(optimizer, hypernetwork.step) - if scheduler.finished: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = stack_conds([entry.cond for entry in entries]).to(devices.device) - # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x - del c - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - for entry in entries: - loss_dict[entry.filename].append(loss.item()) - - optimizer.zero_grad() - weights[0].grad = None - loss.backward() - - if weights[0].grad is None: - steps_without_grad += 1 - else: - steps_without_grad = 0 - assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - - optimizer.step() - - steps_done = hypernetwork.step + 1 - - if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): - raise RuntimeError("Loss diverged.") + for i, entries in pbar: + hypernetwork.step = i + ititial_step + if len(loss_dict) > 0: + previous_mean_losses = [i[-1] for i in loss_dict.values()] + previous_mean_loss = mean(previous_mean_losses) - if len(previous_mean_losses) > 1: - std = stdev(previous_mean_losses) + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = stack_conds([entry.cond for entry in entries]).to(devices.device) + # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x + del c + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + for entry in entries: + loss_dict[entry.filename].append(loss.item()) + + optimizer.zero_grad() + weights[0].grad = None + loss.backward() + + if weights[0].grad is None: + steps_without_grad += 1 else: - std = 0 - dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" - pbar.set_description(dataset_loss_info) + steps_without_grad = 0 + assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: - # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) + optimizer.step() - textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { - "loss": f"{previous_mean_loss:.7f}", - "learn_rate": scheduler.learn_rate - }) + steps_done = hypernetwork.step + 1 - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{hypernetwork_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) + if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): + raise RuntimeError("Loss diverged.") + + if len(previous_mean_losses) > 1: + std = stdev(previous_mean_losses) + else: + std = 0 + dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})" + pbar.set_description(dataset_loss_info) - optimizer.zero_grad() - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) + if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: + # Before saving, change name to match current checkpoint. + hypernetwork.name = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') + hypernetwork.save(last_saved_file) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - ) + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{previous_mean_loss:.7f}", + "learn_rate": scheduler.learn_rate + }) - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{hypernetwork_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) - preview_text = p.prompt + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) - processed = processing.process_images(p) - image = processed.images[0] if len(processed.images)>0 else None + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + ) - if unload: - shared.sd_model.cond_stage_model.to(devices.cpu) - shared.sd_model.first_stage_model.to(devices.cpu) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 - if image is not None: - shared.state.current_image = image - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + preview_text = p.prompt - shared.state.job_no = hypernetwork.step + processed = processing.process_images(p) + image = processed.images[0] if len(processed.images)>0 else None - shared.state.textinfo = f""" + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + + if image is not None: + shared.state.current_image = image + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f"""

Loss: {previous_mean_loss:.7f}
Step: {hypernetwork.step}
@@ -512,14 +510,7 @@ Last saved hypernetwork: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - finally: - if weights: - for weight in weights: - weight.requires_grad = False - if unload: - shared.sd_model.cond_stage_model.to(devices.device) - shared.sd_model.first_stage_model.to(devices.device) - + report_statistics(loss_dict) checkpoint = sd_models.select_checkpoint() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fd7f0897..44f06443 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -283,113 +283,111 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, entries in pbar: + embedding.step = i + ititial_step - try: - for i, entries in pbar: - embedding.step = i + ititial_step + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break - scheduler.apply(optimizer, embedding.step) - if scheduler.finished: - break + if shared.state.interrupted: + break - if shared.state.interrupted: - break + with torch.autocast("cuda"): + c = cond_model([entry.cond_text for entry in entries]) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) + loss = shared.sd_model(x, c)[0] + del x - with torch.autocast("cuda"): - c = cond_model([entry.cond_text for entry in entries]) - x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] - del x + losses[embedding.step % losses.shape[0]] = loss.item() - losses[embedding.step % losses.shape[0]] = loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() - optimizer.zero_grad() - loss.backward() - optimizer.step() + steps_done = embedding.step + 1 - steps_done = embedding.step + 1 + epoch_num = embedding.step // len(ds) + epoch_step = embedding.step % len(ds) - epoch_num = embedding.step // len(ds) - epoch_step = embedding.step % len(ds) + pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}") + if embedding_dir is not None and steps_done % save_embedding_every == 0: + # Before saving, change name to match current checkpoint. + embedding.name = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') + embedding.save(last_saved_file) + embedding_yet_to_be_embedded = True - if embedding_dir is not None and steps_done % save_embedding_every == 0: - # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) - embedding_yet_to_be_embedded = True + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) - write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate - }) + if images_dir is not None and steps_done % create_image_every == 0: + forced_filename = f'{embedding_name}-{steps_done}' + last_saved_image = os.path.join(images_dir, forced_filename) + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + do_not_save_grid=True, + do_not_save_samples=True, + do_not_reload_embeddings=True, + ) - if images_dir is not None and steps_done % create_image_every == 0: - forced_filename = f'{embedding_name}-{steps_done}' - last_saved_image = os.path.join(images_dir, forced_filename) - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - do_not_save_grid=True, - do_not_save_samples=True, - do_not_reload_embeddings=True, - ) + if preview_from_txt2img: + p.prompt = preview_prompt + p.negative_prompt = preview_negative_prompt + p.steps = preview_steps + p.sampler_index = preview_sampler_index + p.cfg_scale = preview_cfg_scale + p.seed = preview_seed + p.width = preview_width + p.height = preview_height + else: + p.prompt = entries[0].cond_text + p.steps = 20 + p.width = training_width + p.height = training_height - if preview_from_txt2img: - p.prompt = preview_prompt - p.negative_prompt = preview_negative_prompt - p.steps = preview_steps - p.sampler_index = preview_sampler_index - p.cfg_scale = preview_cfg_scale - p.seed = preview_seed - p.width = preview_width - p.height = preview_height - else: - p.prompt = entries[0].cond_text - p.steps = 20 - p.width = training_width - p.height = training_height + preview_text = p.prompt - preview_text = p.prompt + processed = processing.process_images(p) + image = processed.images[0] - processed = processing.process_images(p) - image = processed.images[0] + shared.state.current_image = image - shared.state.current_image = image + if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: - if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: + last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') - last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') + info = PngImagePlugin.PngInfo() + data = torch.load(last_saved_file) + info.add_text("sd-ti-embedding", embedding_to_b64(data)) - info = PngImagePlugin.PngInfo() - data = torch.load(last_saved_file) - info.add_text("sd-ti-embedding", embedding_to_b64(data)) + title = "<{}>".format(data.get('name', '???')) - title = "<{}>".format(data.get('name', '???')) + try: + vectorSize = list(data['string_to_param'].values())[0].shape[0] + except Exception as e: + vectorSize = '?' - try: - vectorSize = list(data['string_to_param'].values())[0].shape[0] - except Exception as e: - vectorSize = '?' + checkpoint = sd_models.select_checkpoint() + footer_left = checkpoint.model_name + footer_mid = '[{}]'.format(checkpoint.hash) + footer_right = '{}v {}s'.format(vectorSize, steps_done) - checkpoint = sd_models.select_checkpoint() - footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) - footer_right = '{}v {}s'.format(vectorSize, steps_done) + captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) + captioned_image = insert_image_data_embed(captioned_image, data) - captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) - captioned_image = insert_image_data_embed(captioned_image, data) + captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) + embedding_yet_to_be_embedded = False - captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) - embedding_yet_to_be_embedded = False + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) + last_saved_image += f", prompt: {preview_text}" - last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) - last_saved_image += f", prompt: {preview_text}" + shared.state.job_no = embedding.step - shared.state.job_no = embedding.step - - shared.state.textinfo = f""" + shared.state.textinfo = f"""

Loss: {losses.mean():.7f}
Step: {embedding.step}
@@ -398,9 +396,6 @@ Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

""" - finally: - if embedding and embedding.vec is not None: - embedding.vec.requires_grad = False checkpoint = sd_models.select_checkpoint() From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:49:29 +0700 Subject: [PATCH 5/6] Add missing info on hypernetwork/embedding model log Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513 Also group the saving into one --- modules/hypernetworks/hypernetwork.py | 31 ++++++++++----- .../textual_inversion/textual_inversion.py | 39 ++++++++++++------- 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..86daf825 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log images_dir = None hypernetwork = shared.loaded_hypernetwork + checkpoint = sd_models.select_checkpoint() ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", @@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}
""" report_statistics(loss_dict) - checkpoint = sd_models.select_checkpoint() - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). - hypernetwork.name = hypernetwork_name - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(filename) + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) return hypernetwork, filename + +def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): + old_hypernetwork_name = hypernetwork.name + old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None + old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None + try: + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.name = hypernetwork_name + hypernetwork.save(filename) + except: + hypernetwork.sd_checkpoint = old_sd_checkpoint + hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name + hypernetwork.name = old_hypernetwork_name + raise diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 44f06443..ee9917ce 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -119,7 +119,7 @@ class EmbeddingDatabase: vec = emb.detach().to(devices.device, dtype=torch.float32) embedding = Embedding(vec, name) embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('hash', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) self.register_embedding(embedding, shared.sd_model) @@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc hijack = sd_hijack.model_hijack embedding = hijack.embedding_db.word_embeddings[embedding_name] + checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 if ititial_step > steps: @@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. - embedding.name = f'{embedding_name}-{steps_done}' - last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt') - embedding.save(last_saved_file) + embedding_name_every = f'{embedding_name}-{steps_done}' + last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') + save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { @@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}

""" - checkpoint = sd_models.select_checkpoint() - - embedding.sd_checkpoint = checkpoint.hash - embedding.sd_checkpoint_name = checkpoint.model_name - embedding.cached_checksum = None - # Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). - embedding.name = embedding_name - filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt') - embedding.save(filename) + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) return embedding, filename + +def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): + old_embedding_name = embedding.name + old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None + old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None + old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None + try: + embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint_name = checkpoint.model_name + if remove_cached_checksum: + embedding.cached_checksum = None + embedding.name = embedding_name + embedding.save(filename) + except: + embedding.sd_checkpoint = old_sd_checkpoint + embedding.sd_checkpoint_name = old_sd_checkpoint_name + embedding.name = old_embedding_name + embedding.cached_checksum = old_cached_checksum + raise From 3d58510f214c645ce5cdb261aa47df6573b239e9 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:54:59 +0700 Subject: [PATCH 6/6] Fix dataset still being loaded even when training will be skipped --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/textual_inversion.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 86daf825..07acadc9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -364,7 +364,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log checkpoint = sd_models.select_checkpoint() ititial_step = hypernetwork.step or 0 - if ititial_step > steps: + if ititial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return hypernetwork, filename diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ee9917ce..e0babb46 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -262,7 +262,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc checkpoint = sd_models.select_checkpoint() ititial_step = embedding.step or 0 - if ititial_step > steps: + if ititial_step >= steps: shared.state.textinfo = f"Model has already been trained beyond specified max steps" return embedding, filename