diff --git a/.gitignore b/.gitignore index 5985023..7bcfc48 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ venv venv1 mytraining.ps __pycache__ +.vscode \ No newline at end of file diff --git a/dreambooth_gui.py b/dreambooth_gui.py index 4f8aa26..bca3184 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -1,5 +1,6 @@ # v1: initial release # v2: add open and save folder icons +# v3: Add new Utilities tab for Dreambooth folder preparation import gradio as gr import json @@ -160,7 +161,8 @@ def open_configuration( # Return the values of the variables as a dictionary return ( file_path, - my_data.get("pretrained_model_name_or_path", pretrained_model_name_or_path), + my_data.get("pretrained_model_name_or_path", + pretrained_model_name_or_path), my_data.get("v2", v2), my_data.get("v_parameterization", v_parameterization), my_data.get("logging_dir", logging_dir), @@ -177,7 +179,8 @@ def open_configuration( my_data.get("mixed_precision", mixed_precision), my_data.get("save_precision", save_precision), my_data.get("seed", seed), - my_data.get("num_cpu_threads_per_process", num_cpu_threads_per_process), + my_data.get("num_cpu_threads_per_process", + num_cpu_threads_per_process), my_data.get("convert_to_safetensors", convert_to_safetensors), my_data.get("convert_to_ckpt", convert_to_ckpt), my_data.get("cache_latent", cache_latent), @@ -225,6 +228,7 @@ def train_model( use_8bit_adam, xformers, ): + def save_inference_file(output_dir, v2, v_parameterization): # Copy inference model for v2 if required if v2 and v_parameterization: @@ -242,8 +246,7 @@ def train_model( # Get a list of all subfolders in train_data_dir subfolders = [ - f - for f in os.listdir(train_data_dir) + f for f in os.listdir(train_data_dir) if os.path.isdir(os.path.join(train_data_dir, f)) ] @@ -255,16 +258,11 @@ def train_model( repeats = int(folder.split("_")[0]) # Count the number of images in the folder - num_images = len( - [ - f - for f in os.listdir(os.path.join(train_data_dir, folder)) - if f.endswith(".jpg") - or f.endswith(".jpeg") - or f.endswith(".png") - or f.endswith(".webp") - ] - ) + num_images = len([ + f for f in os.listdir(os.path.join(train_data_dir, folder)) + if f.endswith(".jpg") or f.endswith(".jpeg") or f.endswith(".png") + or f.endswith(".webp") + ]) # Calculate the total number of steps for this folder steps = repeats * num_images @@ -287,9 +285,8 @@ def train_model( # calculate max_train_steps max_train_steps = int( math.ceil( - float(total_steps) / int(train_batch_size) * int(epoch) * int(reg_factor) - ) - ) + float(total_steps) / int(train_batch_size) * int(epoch) * + int(reg_factor))) print(f"max_train_steps = {max_train_steps}") # calculate stop encoder training @@ -297,8 +294,7 @@ def train_model( stop_text_encoder_training = 0 else: stop_text_encoder_training = math.ceil( - float(max_train_steps) / 100 * int(stop_text_encoder_training_pct) - ) + float(max_train_steps) / 100 * int(stop_text_encoder_training_pct)) print(f"stop_text_encoder_training = {stop_text_encoder_training}") lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100)) @@ -326,8 +322,9 @@ def train_model( if xformers: run_cmd += " --xformers" run_cmd += f" --pretrained_model_name_or_path={pretrained_model_name_or_path}" - run_cmd += f" --train_data_dir={train_data_dir}" - run_cmd += f" --reg_data_dir={reg_data_dir}" + run_cmd += f' --train_data_dir="{train_data_dir}"' + if len(reg_data_dir): + run_cmd += f' --reg_data_dir="{reg_data_dir}"' run_cmd += f" --resolution={max_resolution}" run_cmd += f" --output_dir={output_dir}" run_cmd += f" --train_batch_size={train_batch_size}" @@ -362,7 +359,9 @@ def train_model( save_inference_file(output_dir, v2, v_parameterization) if convert_to_safetensors: - print(f"Converting diffuser model {last_dir} to {last_dir}.safetensors") + print( + f"Converting diffuser model {last_dir} to {last_dir}.safetensors" + ) os.system( f"python ./tools/convert_diffusers20_original_sd.py {last_dir} {last_dir}.safetensors --{save_precision}" ) @@ -435,9 +434,9 @@ def remove_doublequote(file_path): def get_file_path(file_path): - file_path = fileopenbox( - "Select the config file to load", default=file_path, filetypes="*.json" - ) + file_path = fileopenbox("Select the config file to load", + default=file_path, + filetypes="*.json") return file_path @@ -448,6 +447,84 @@ def get_folder_path(): return folder_path +def dreambooth_folder_preparation( + util_training_images_dir_input, + util_training_images_repeat_input, + util_training_images_prompt_input, + util_regularization_images_dir_input, + util_regularization_images_repeat_input, + util_regularization_images_prompt_input, + util_training_dir_input, +): + + # Check if the input variables are empty + if (not len(util_training_dir_input)): + print( + "Destination training directory is missing... can't perform the required task..." + ) + return + else: + # Create the util_training_dir_input directory if it doesn't exist + os.makedirs(util_training_dir_input, exist_ok=True) + + # Create the training_dir path + if (not len(util_training_images_prompt_input) + or not util_training_images_repeat_input > 0): + print( + "Training images directory or repeats is missing... can't perform the required task..." + ) + return + else: + training_dir = os.path.join( + util_training_dir_input, + f"img/{int(util_training_images_repeat_input)}_{util_training_images_prompt_input}", + ) + + # Remove folders if they exist + if os.path.exists(training_dir): + print(f"Removing existing directory {training_dir}...") + shutil.rmtree(training_dir) + + # Copy the training images to their respective directories + print(f"Copy {util_training_images_dir_input} to {training_dir}...") + shutil.copytree(util_training_images_dir_input, training_dir) + + # Create the regularization_dir path + if (not len(util_regularization_images_prompt_input) + or not util_regularization_images_repeat_input > 0): + print( + "Regularization images directory or repeats is missing... not copying regularisation images..." + ) + else: + regularization_dir = os.path.join( + util_training_dir_input, + f"reg/{int(util_regularization_images_repeat_input)}_{util_regularization_images_prompt_input}", + ) + + # Remove folders if they exist + if os.path.exists(regularization_dir): + print(f"Removing existing directory {regularization_dir}...") + shutil.rmtree(regularization_dir) + + # Copy the regularisation images to their respective directories + print( + f"Copy {util_regularization_images_dir_input} to {regularization_dir}..." + ) + shutil.copytree(util_regularization_images_dir_input, + regularization_dir) + + print( + f"Done creating kohya_ss training folder structure at {util_training_dir_input}..." + ) + +def copy_info_to_Directories_tab(training_folder): + img_folder = os.path.join(training_folder, "img") + reg_folder = os.path.join(training_folder, "reg") + model_folder = os.path.join(training_folder, "model") + log_folder = os.path.join(training_folder, "log") + + return img_folder, reg_folder, model_folder, log_folder + css = "" if os.path.exists("./style.css"): @@ -465,19 +542,20 @@ with interface: with gr.Row(): button_open_config = gr.Button("Open 📂", elem_id="open_folder") button_save_config = gr.Button("Save 💾", elem_id="open_folder") - button_save_as_config = gr.Button("Save as... 💾", elem_id="open_folder") + button_save_as_config = gr.Button("Save as... 💾", + elem_id="open_folder") config_file_name = gr.Textbox( - label="", placeholder="type config file path or use buttons..." - ) - config_file_name.change( - remove_doublequote, inputs=[config_file_name], outputs=[config_file_name] - ) + label="", placeholder="type config file path or use buttons...") + config_file_name.change(remove_doublequote, + inputs=[config_file_name], + outputs=[config_file_name]) with gr.Tab("Source model"): # Define the input elements with gr.Row(): pretrained_model_name_or_path_input = gr.Textbox( label="Pretrained model name or path", - placeholder="enter the path to custom model or name of pretrained model", + placeholder= + "enter the path to custom model or name of pretrained model", ) model_list = gr.Dropdown( label="(Optional) Model Quick Pick", @@ -493,9 +571,8 @@ with interface: ) with gr.Row(): v2_input = gr.Checkbox(label="v2", value=True) - v_parameterization_input = gr.Checkbox( - label="v_parameterization", value=False - ) + v_parameterization_input = gr.Checkbox(label="v_parameterization", + value=False) pretrained_model_name_or_path_input.change( remove_doublequote, inputs=[pretrained_model_name_or_path_input], @@ -515,31 +592,40 @@ with interface: with gr.Row(): train_data_dir_input = gr.Textbox( label="Image folder", - placeholder="Directory where the training folders containing the images are located", - ) - train_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") - train_data_dir_input_folder.click( - get_folder_path, outputs=train_data_dir_input + placeholder= + "Directory where the training folders containing the images are located", ) + train_data_dir_input_folder = gr.Button( + "📂", elem_id="open_folder_small") + train_data_dir_input_folder.click(get_folder_path, + outputs=train_data_dir_input) reg_data_dir_input = gr.Textbox( label="Regularisation folder", - placeholder="(Optional) Directory where where the regularization folders containing the images are located", + placeholder= + "(Optional) Directory where where the regularization folders containing the images are located", ) - reg_data_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") - reg_data_dir_input_folder.click(get_folder_path, outputs=reg_data_dir_input) + reg_data_dir_input_folder = gr.Button("📂", + elem_id="open_folder_small") + reg_data_dir_input_folder.click(get_folder_path, + outputs=reg_data_dir_input) with gr.Row(): output_dir_input = gr.Textbox( label="Output directory", placeholder="Directory to output trained model", ) - output_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") - output_dir_input_folder.click(get_folder_path, outputs=output_dir_input) + output_dir_input_folder = gr.Button("📂", + elem_id="open_folder_small") + output_dir_input_folder.click(get_folder_path, + outputs=output_dir_input) logging_dir_input = gr.Textbox( label="Logging directory", - placeholder="Optional: enable logging and output TensorBoard log to this directory", + placeholder= + "Optional: enable logging and output TensorBoard log to this directory", ) - logging_dir_input_folder = gr.Button("📂", elem_id="open_folder_small") - logging_dir_input_folder.click(get_folder_path, outputs=logging_dir_input) + logging_dir_input_folder = gr.Button("📂", + elem_id="open_folder_small") + logging_dir_input_folder.click(get_folder_path, + outputs=logging_dir_input) train_data_dir_input.change( remove_doublequote, inputs=[train_data_dir_input], @@ -550,12 +636,12 @@ with interface: inputs=[reg_data_dir_input], outputs=[reg_data_dir_input], ) - output_dir_input.change( - remove_doublequote, inputs=[output_dir_input], outputs=[output_dir_input] - ) - logging_dir_input.change( - remove_doublequote, inputs=[logging_dir_input], outputs=[logging_dir_input] - ) + output_dir_input.change(remove_doublequote, + inputs=[output_dir_input], + outputs=[output_dir_input]) + logging_dir_input.change(remove_doublequote, + inputs=[logging_dir_input], + outputs=[logging_dir_input]) with gr.Tab("Training parameters"): with gr.Row(): learning_rate_input = gr.Textbox(label="Learning rate", value=1e-6) @@ -573,11 +659,14 @@ with interface: ) lr_warmup_input = gr.Textbox(label="LR warmup", value=0) with gr.Row(): - train_batch_size_input = gr.Slider( - minimum=1, maximum=32, label="Train batch size", value=1, step=1 - ) + train_batch_size_input = gr.Slider(minimum=1, + maximum=32, + label="Train batch size", + value=1, + step=1) epoch_input = gr.Textbox(label="Epoch", value=1) - save_every_n_epochs_input = gr.Textbox(label="Save every N epochs", value=1) + save_every_n_epochs_input = gr.Textbox(label="Save every N epochs", + value=1) with gr.Row(): mixed_precision_input = gr.Dropdown( label="Mixed precision", @@ -606,11 +695,13 @@ with interface: ) with gr.Row(): seed_input = gr.Textbox(label="Seed", value=1234) - max_resolution_input = gr.Textbox(label="Max resolution", value="512,512") + max_resolution_input = gr.Textbox(label="Max resolution", + value="512,512") with gr.Row(): caption_extention_input = gr.Textbox( label="Caption Extension", - placeholder="(Optional) Extension for caption files. default: .caption", + placeholder= + "(Optional) Extension for caption files. default: .caption", ) stop_text_encoder_training_input = gr.Slider( minimum=0, @@ -621,28 +712,108 @@ with interface: ) with gr.Row(): use_safetensors_input = gr.Checkbox( - label="Use safetensor when saving", value=False - ) - enable_bucket_input = gr.Checkbox(label="Enable buckets", value=False) + label="Use safetensor when saving", value=False) + enable_bucket_input = gr.Checkbox(label="Enable buckets", + value=True) cache_latent_input = gr.Checkbox(label="Cache latent", value=True) gradient_checkpointing_input = gr.Checkbox( - label="Gradient checkpointing", value=False - ) + label="Gradient checkpointing", value=False) with gr.Row(): full_fp16_input = gr.Checkbox( - label="Full fp16 training (experimental)", value=False - ) - no_token_padding_input = gr.Checkbox(label="No tokan padding", value=False) - use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam", value=True) - xformers_input = gr.Checkbox(label="USe xformers", value=True) + label="Full fp16 training (experimental)", value=False) + no_token_padding_input = gr.Checkbox(label="No token padding", + value=False) + use_8bit_adam_input = gr.Checkbox(label="Use 8bit adam", + value=True) + xformers_input = gr.Checkbox(label="Use xformers", value=True) with gr.Tab("Model conversion"): convert_to_safetensors_input = gr.Checkbox( - label="Convert to SafeTensors", value=False - ) - convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT", value=False) + label="Convert to SafeTensors", value=False) + convert_to_ckpt_input = gr.Checkbox(label="Convert to CKPT", + value=False) - button_run = gr.Button("Run") + with gr.Tab("Utilities"): + with gr.Tab("Dreambooth folder preparation"): + gr.Markdown( + "This utility will create the required folder structure for the training images and regularisation images that is required for kohys_ss Dreambooth method to properly run." + ) + with gr.Row(): + util_training_images_dir_input = gr.Textbox( + label="Training images", + placeholder="Directory containing the training images", + interactive=True, + ) + button_util_training_images_dir_input = gr.Button( + "📂", elem_id="open_folder_small") + button_util_training_images_dir_input.click( + get_folder_path, outputs=util_training_images_dir_input) + util_training_images_repeat_input = gr.Number( + label="Repeats", + value=40, + interactive=True, + elem_id="number_input") + util_training_images_prompt_input = gr.Textbox( + label="Training images prompt", + placeholder="Prompt for the training images. Eg: asd", + interactive=True, + ) + with gr.Row(): + util_regularization_images_dir_input = gr.Textbox( + label="Regularisation images", + placeholder= + "Directory containing the regularisation images", + interactive=True, + ) + button_util_regularization_images_dir_input = gr.Button( + "📂", elem_id="open_folder_small") + button_util_regularization_images_dir_input.click( + get_folder_path, + outputs=util_regularization_images_dir_input) + util_regularization_images_repeat_input = gr.Number( + label="Repeats", + value=1, + interactive=True, + elem_id="number_input") + util_regularization_images_prompt_input = gr.Textbox( + label="Regularisation images class prompt", + placeholder= + "Prompt for the regularisation images. Eg: person", + interactive=True, + ) + with gr.Row(): + util_training_dir_input = gr.Textbox( + label="Destination training directory", + placeholder= + "Directory where formatted training and regularisation images will be placed", + interactive=True, + ) + button_util_training_dir_input = gr.Button( + "📂", elem_id="open_folder_small") + button_util_training_dir_input.click( + get_folder_path, outputs=util_training_dir_input) + button_prepare_training_data = gr.Button("Prepare training data") + button_prepare_training_data.click( + dreambooth_folder_preparation, + inputs=[ + util_training_images_dir_input, + util_training_images_repeat_input, + util_training_images_prompt_input, + util_regularization_images_dir_input, + util_regularization_images_repeat_input, + util_regularization_images_prompt_input, + util_training_dir_input, + ], + ) + button_copy_info_to_Directories_tab = gr.Button( + "Copy info to Directories Tab") + + button_run = gr.Button("Train model") + + button_copy_info_to_Directories_tab.click( + copy_info_to_Directories_tab, + inputs=[util_training_dir_input], + outputs=[train_data_dir_input, reg_data_dir_input, output_dir_input, logging_dir_input]) button_open_config.click( open_configuration, diff --git a/style.css b/style.css index 6545eec..754673f 100644 --- a/style.css +++ b/style.css @@ -1,5 +1,5 @@ #open_folder_small{ - height: fit-content; + height: auto; min-width: auto; flex-grow: 0; padding-left: 0.25em; @@ -7,8 +7,15 @@ } #open_folder{ - height: fit-content; + height: auto; flex-grow: 0; padding-left: 0.25em; padding-right: 0.25em; +} + +#number_input{ + min-width: min-content; + flex-grow: 0.3; + padding-left: 0.75em; + padding-right: 0.75em; } \ No newline at end of file