From d4ea5f4d8631f778d11efcde397e4a5b8801d43b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 19:03:08 +0300 Subject: [PATCH] add an option to unload models during hypernetwork training to save VRAM --- modules/hypernetworks/hypernetwork.py | 25 +++++++++++----- modules/hypernetworks/ui.py | 4 ++- modules/shared.py | 4 +++ modules/textual_inversion/dataset.py | 29 +++++++++++++------ .../textual_inversion/textual_inversion.py | 2 +- 5 files changed, 46 insertions(+), 18 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b081f14e..4700e1ec 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + unload = shared.opts.unload_models_when_training if save_hypernetwork_every > 0: hypernetwork_dir = os.path.join(log_directory, "hypernetworks") @@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, else: images_dir = None - cond_model = shared.sd_model.cond_stage_model - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() @@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, return hypernetwork, filename pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: + for i, (x, text, cond) in pbar: hypernetwork.step = i + ititial_step if hypernetwork.step > steps: @@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, break with torch.autocast("cuda"): - c = cond_model([text]) - + cond = cond.to(devices.device) x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] + loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x + del cond losses[hypernetwork.step % losses.shape[0]] = loss.item() @@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, preview_text = text if preview_image_prompt == "" else preview_image_prompt + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, prompt=preview_text, @@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, processed = processing.process_images(p) image = processed.images[0] + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + shared.state.current_image = image image.save(last_saved_image) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 3541a388..c67facbb 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -5,7 +5,7 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared +from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork @@ -41,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)} raise finally: shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/shared.py b/modules/shared.py index 20b45f23..c1092ff7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -228,6 +228,10 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), +})) + options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 4d006366..f61f40d3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,14 +8,14 @@ from torchvision import transforms import random import tqdm -from modules import devices +from modules import devices, shared import re re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): self.placeholder_token = placeholder_token @@ -32,6 +32,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' + cond_model = shared.sd_model.cond_stage_model + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -53,7 +55,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) - self.dataset.append((init_latent, filename_tokens)) + if include_cond: + text = self.create_text(filename_tokens) + cond = cond_model([text]).to(devices.cpu) + else: + cond = None + + self.dataset.append((init_latent, filename_tokens, cond)) self.length = len(self.dataset) * repeats @@ -64,6 +72,12 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + def create_text(self, filename_tokens): + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + return text + def __len__(self): return self.length @@ -72,10 +86,7 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens = self.dataset[index] + x, filename_tokens, cond = self.dataset[index] - text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) - - return x, text + text = self.create_text(filename_tokens) + return x, text, cond diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb05cdc6..35f4bd9e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini return embedding, filename pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text) in pbar: + for i, (x, text, _) in pbar: embedding.step = i + ititial_step if embedding.step > steps: