diff --git a/.gitignore b/.gitignore index 03b0122..bc80c48 100644 --- a/.gitignore +++ b/.gitignore @@ -1,243 +1,12 @@ -# Kohya_SS Specifics +venv +__pycache__ cudnn_windows .vscode +*.egg-info +build wd14_tagger_model .DS_Store locon gui-user.bat gui-user.ps1 -*.whl* -.idea - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider -# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 - -# User-specific stuff -.idea/**/workspace.xml -.idea/**/tasks.xml -.idea/**/usage.statistics.xml -.idea/**/dictionaries -.idea/**/shelf - -# AWS User-specific -.idea/**/aws.xml - -# Generated files -.idea/**/contentModel.xml - -# Sensitive or high-churn files -.idea/**/dataSources/ -.idea/**/dataSources.ids -.idea/**/dataSources.local.xml -.idea/**/sqlDataSources.xml -.idea/**/dynamic.xml -.idea/**/uiDesigner.xml -.idea/**/dbnavigator.xml - -# Gradle -.idea/**/gradle.xml -.idea/**/libraries - -# Gradle and Maven with auto-import -# When using Gradle or Maven with auto-import, you should exclude module files, -# since they will be recreated, and may cause churn. Uncomment if using -# auto-import. -# .idea/artifacts -# .idea/compiler.xml -# .idea/jarRepositories.xml -# .idea/modules.xml -# .idea/*.iml -# .idea/modules -# *.iml -# *.ipr - -# CMake -cmake-build-*/ - -# Mongo Explorer plugin -.idea/**/mongoSettings.xml - -# File-based project format -*.iws - -# IntelliJ -out/ - -# mpeltonen/sbt-idea plugin -.idea_modules/ - -# JIRA plugin -atlassian-ide-plugin.xml - -# Cursive Clojure plugin -.idea/replstate.xml - -# SonarLint plugin -.idea/sonarlint/ - -# Crashlytics plugin (for Android Studio and IntelliJ) -com_crashlytics_export_strings.xml -crashlytics.properties -crashlytics-build.properties -fabric.properties - -# Editor-based Rest Client -.idea/httpRequests - -# Android studio 3.1+ serialized cache file -.idea/caches/build_file_checksums.ser library/__init__.py diff --git a/dreambooth_gui.py b/dreambooth_gui.py index c2185c1..e93f96e 100644 --- a/dreambooth_gui.py +++ b/dreambooth_gui.py @@ -3,17 +3,14 @@ # v3: Add new Utilities tab for Dreambooth folder preparation # v3.1: Adding captionning of images to utilities -import argparse +import gradio as gr import json import math import os -import pathlib import subprocess -import sys - -import gradio as gr - -from library.common_gui_functions import ( +import pathlib +import argparse +from library.common_gui import ( get_folder_path, remove_doublequote, get_file_path, @@ -29,89 +26,89 @@ from library.common_gui_functions import ( gradio_source_model, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, + check_if_model_exist, ) -from library.common_utilities import CommonUtilities -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) -from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 def save_configuration( - save_as, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + save_as, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): # Get list of function parameters and values parameters = list(locals().items()) @@ -125,12 +122,12 @@ def save_configuration( file_path = get_saveasfile_path(file_path) else: print('Save...') - if file_path is None or file_path == '': + if file_path == None or file_path == '': file_path = get_saveasfile_path(file_path) # print(file_path) - if file_path is None or file_path == '': + if file_path == None or file_path == '': return original_file_path # In case a file_path was provided and the user decide to cancel the open action # Return the values of the variables as a dictionary @@ -138,10 +135,10 @@ def save_configuration( name: value for name, value in parameters # locals().items() if name - not in [ - 'file_path', - 'save_as', - ] + not in [ + 'file_path', + 'save_as', + ] } # Extract the destination directory from the file path @@ -159,73 +156,69 @@ def save_configuration( def open_configuration( - ask_for_file, - file_path, - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + ask_for_file, + file_path, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): - print("open_configuration called") - print(f"locals length: {len(locals())}") - print(f"locals: {locals()}") - # Get list of function parameters and values parameters = list(locals().items()) @@ -233,25 +226,18 @@ def open_configuration( original_file_path = file_path - if ask_for_file and file_path is not None: - print(f"File path: {file_path}") - file_path, canceled = get_file_path(file_path=file_path, filedialog_type="json") + if ask_for_file: + file_path = get_file_path(file_path) - if canceled: - return (None,) + (None,) * (len(parameters) - 2) - - if not file_path == '' and file_path is not None: + if not file_path == '' and not file_path == None: + # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) - if CommonUtilities.is_valid_config(my_data): - print('Loading config...') - my_data = update_my_data(my_data) - else: - print("Invalid configuration file.") - my_data = {} - show_message_box("Invalid configuration file.") + print('Loading config...') + # Update values to fix deprecated use_8bit_adam checkbox and set appropriate optimizer if it is set to True + my_data = update_my_data(my_data) else: - file_path = original_file_path + file_path = original_file_path # In case a file_path was provided and the user decide to cancel the open action my_data = {} values = [file_path] @@ -259,92 +245,90 @@ def open_configuration( # Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found if not key in ['ask_for_file', 'file_path']: values.append(my_data.get(key, value)) - # Print the number of returned values - print(f"Returning: {values}") return tuple(values) def train_model( - pretrained_model_name_or_path, - v2, - v_parameterization, - logging_dir, - train_data_dir, - reg_data_dir, - output_dir, - max_resolution, - learning_rate, - lr_scheduler, - lr_warmup, - train_batch_size, - epoch, - save_every_n_epochs, - mixed_precision, - save_precision, - seed, - num_cpu_threads_per_process, - cache_latents, - caption_extension, - enable_bucket, - gradient_checkpointing, - full_fp16, - no_token_padding, - stop_text_encoder_training_pct, - # use_8bit_adam, - xformers, - save_model_as, - shuffle_caption, - save_state, - resume, - prior_loss_weight, - color_aug, - flip_aug, - clip_skip, - vae, - output_name, - max_token_length, - max_train_epochs, - max_data_loader_n_workers, - mem_eff_attn, - gradient_accumulation_steps, - model_list, # Keep this. Yes, it is unused here but required given the common list used - keep_tokens, - persistent_data_loader_workers, - bucket_no_upscale, - random_crop, - bucket_reso_steps, - caption_dropout_every_n_epochs, - caption_dropout_rate, - optimizer, - optimizer_args, - noise_offset, - sample_every_n_steps, - sample_every_n_epochs, - sample_sampler, - sample_prompts, - additional_parameters, - vae_batch_size, - min_snr_gamma, + pretrained_model_name_or_path, + v2, + v_parameterization, + logging_dir, + train_data_dir, + reg_data_dir, + output_dir, + max_resolution, + learning_rate, + lr_scheduler, + lr_warmup, + train_batch_size, + epoch, + save_every_n_epochs, + mixed_precision, + save_precision, + seed, + num_cpu_threads_per_process, + cache_latents, + caption_extension, + enable_bucket, + gradient_checkpointing, + full_fp16, + no_token_padding, + stop_text_encoder_training_pct, + # use_8bit_adam, + xformers, + save_model_as, + shuffle_caption, + save_state, + resume, + prior_loss_weight, + color_aug, + flip_aug, + clip_skip, + vae, + output_name, + max_token_length, + max_train_epochs, + max_data_loader_n_workers, + mem_eff_attn, + gradient_accumulation_steps, + model_list, # Keep this. Yes, it is unused here but required given the common list used + keep_tokens, + persistent_data_loader_workers, + bucket_no_upscale, + random_crop, + bucket_reso_steps, + caption_dropout_every_n_epochs, + caption_dropout_rate, + optimizer, + optimizer_args, + noise_offset, + sample_every_n_steps, + sample_every_n_epochs, + sample_sampler, + sample_prompts, + additional_parameters, + vae_batch_size, + min_snr_gamma, ): if pretrained_model_name_or_path == '': - show_message_box('Source model information is missing') + msgbox('Source model information is missing') return if train_data_dir == '': - show_message_box('Image folder path is missing') + msgbox('Image folder path is missing') return if not os.path.exists(train_data_dir): - show_message_box('Image folder does not exist') + msgbox('Image folder does not exist') return if reg_data_dir != '': if not os.path.exists(reg_data_dir): - show_message_box('Regularisation folder does not exist') + msgbox('Regularisation folder does not exist') return if output_dir == '': - show_message_box('Output folder path is missing') + msgbox('Output folder path is missing') return if check_if_model_exist(output_name, output_dir, save_model_as): @@ -355,7 +339,7 @@ def train_model( f for f in os.listdir(train_data_dir) if os.path.isdir(os.path.join(train_data_dir, f)) - and not f.startswith('.') + and not f.startswith('.') ] # Check if subfolders are present. If not let the user know and return @@ -387,11 +371,11 @@ def train_model( [ f for f, lower_f in ( - (file, file.lower()) - for file in os.listdir( - os.path.join(train_data_dir, folder) - ) - ) + (file, file.lower()) + for file in os.listdir( + os.path.join(train_data_dir, folder) + ) + ) if lower_f.endswith(('.jpg', '.jpeg', '.png', '.webp')) ] ) @@ -562,10 +546,10 @@ def train_model( def dreambooth_tab( - train_data_dir=gr.Textbox(), - reg_data_dir=gr.Textbox(), - output_dir=gr.Textbox(), - logging_dir=gr.Textbox(), + train_data_dir=gr.Textbox(), + reg_data_dir=gr.Textbox(), + output_dir=gr.Textbox(), + logging_dir=gr.Textbox(), ): dummy_db_true = gr.Label(value=True, visible=False) dummy_db_false = gr.Label(value=False, visible=False) @@ -848,20 +832,19 @@ def dreambooth_tab( ] button_open_config.click( - lambda *_args, **kwargs: open_configuration(*_args, **kwargs), + open_configuration, inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - lambda *args, **kwargs: (print("Lambda called"), open_configuration(*args, **kwargs)), - inputs=[dummy_db_true, config_file_name] + settings_list, + open_configuration, + inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) - # Print the number of expected outputs - print(f"Number of expected outputs: {len([config_file_name] + settings_list)}") + button_save_config.click( save_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, diff --git a/fine_tune.py b/fine_tune.py index ad298fd..637a729 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -5,20 +5,22 @@ import argparse import gc import math import os +import toml from multiprocessing import Value +from tqdm import tqdm import torch from accelerate.utils import set_seed +import diffusers from diffusers import DDPMScheduler -from tqdm import tqdm -import library.config_ml_util as config_util -import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.config_ml_util import ( +import library.config_util as config_util +from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight diff --git a/finetune_gui.py b/finetune_gui.py index 4c6e5ce..b085928 100644 --- a/finetune_gui.py +++ b/finetune_gui.py @@ -1,18 +1,17 @@ -import argparse +import gradio as gr import json import math import os -import pathlib import subprocess - -import gradio as gr - -from library.common_gui_functions import ( +import pathlib +import argparse +from library.common_gui import ( get_folder_path, get_file_path, get_saveasfile_path, save_inference_file, gradio_advanced_training, + run_cmd_advanced_training, gradio_training, run_cmd_advanced_training, gradio_config, @@ -23,13 +22,13 @@ from library.common_gui_functions import ( update_my_data, check_if_model_exist, ) -from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -234,7 +233,7 @@ def open_configuration( if ask_for_file: file_path = get_file_path(file_path) - if not file_path == '' and file_path is not None: + if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -800,14 +799,14 @@ def finetune_tab(): button_run.click(train_model, inputs=settings_list) button_open_config.click( - lambda *args, **kwargs: open_configuration(), + open_configuration, inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - lambda *args, **kwargs: open_configuration(), + open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, diff --git a/kohya_gui.py b/kohya_gui.py index 732cd78..f8e0d8c 100644 --- a/kohya_gui.py +++ b/kohya_gui.py @@ -1,18 +1,17 @@ -import argparse -import os -from pathlib import Path - import gradio as gr - +import os +import argparse from dreambooth_gui import dreambooth_tab from finetune_gui import finetune_tab +from textual_inversion_gui import ti_tab +from library.utilities import utilities_tab from library.extract_lora_gui import gradio_extract_lora_tab from library.extract_lycoris_locon_gui import gradio_extract_lycoris_locon_tab from library.merge_lora_gui import gradio_merge_lora_tab from library.resize_lora_gui import gradio_resize_lora_tab -from library.utilities import utilities_tab from lora_gui import lora_tab -from textual_inversion_gui import ti_tab + + def UI(**kwargs): css = '' diff --git a/library/basic_caption_gui.py b/library/basic_caption_gui.py index 0fa7ba6..b2d208d 100644 --- a/library/basic_caption_gui.py +++ b/library/basic_caption_gui.py @@ -1,9 +1,8 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import get_folder_path, add_pre_postfix, find_replace +from easygui import msgbox +import subprocess +from .common_gui import get_folder_path, add_pre_postfix, find_replace +import os def caption_images( @@ -18,11 +17,11 @@ def caption_images( ): # Check for images_dir if not images_dir: - show_message_box('Image folder is missing...') + msgbox('Image folder is missing...') return if not caption_ext: - show_message_box('Please provide an extension for the caption files.') + msgbox('Please provide an extension for the caption files.') return if caption_text: @@ -61,7 +60,7 @@ def caption_images( ) else: if prefix or postfix: - show_message_box( + msgbox( 'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...' ) diff --git a/library/blip_caption_gui.py b/library/blip_caption_gui.py index ffe087b..2e0081d 100644 --- a/library/blip_caption_gui.py +++ b/library/blip_caption_gui.py @@ -1,9 +1,8 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import get_folder_path, add_pre_postfix +from easygui import msgbox +import subprocess +import os +from .common_gui import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -22,16 +21,16 @@ def caption_images( ): # Check for caption_text_input # if caption_text_input == "": - # show_message_box("Caption text is missing...") + # msgbox("Caption text is missing...") # return # Check for images_dir_input if train_data_dir == '': - show_message_box('Image folder is missing...') + msgbox('Image folder is missing...') return if caption_file_ext == '': - show_message_box('Please provide an extension for the caption files.') + msgbox('Please provide an extension for the caption files.') return print(f'Captioning files in {train_data_dir}...') diff --git a/library/common_gui_functions.py b/library/common_gui.py similarity index 86% rename from library/common_gui_functions.py rename to library/common_gui.py index f809de5..b08ac9c 100644 --- a/library/common_gui_functions.py +++ b/library/common_gui.py @@ -1,19 +1,14 @@ -import os -import shutil -import subprocess -from contextlib import contextmanager -import tkinter as tk from tkinter import filedialog, Tk - -import easygui +from easygui import msgbox +import os import gradio as gr - -from library.common_utilities import CommonUtilities +import easygui +import shutil folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 # define a list of substrings to search for v2 base models V2_BASE_MODELS = [ @@ -39,41 +34,6 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS FILE_ENV_EXCLUSION = ['COLAB_GPU', 'RUNPOD_POD_ID'] -@contextmanager -def tk_context(): - root = tk.Tk() - root.withdraw() - try: - yield root - finally: - root.destroy() - - -def open_file_dialog(initial_dir, initial_file, file_types="all"): - current_directory = os.path.dirname(os.path.abspath(__file__)) - - args = ["python", f"{current_directory}/gui_subprocesses.py", "file_dialog"] - if initial_dir: - args.append(initial_dir) - if initial_file: - args.append(initial_file) - if file_types: - args.append(file_types) - - file_path = subprocess.check_output(args).decode("utf-8").strip() - return file_path - - -def show_message_box(message, title=""): - current_directory = os.path.dirname(os.path.abspath(__file__)) - - args = ["python", f"{current_directory}/gui_subprocesses.py", "msgbox", message] - if title: - args.append(title) - - subprocess.run(args) - - def check_if_model_exist(output_name, output_dir, save_model_as): if save_model_as in ['diffusers', 'diffusers_safetendors']: ckpt_folder = os.path.join(output_dir, output_name) @@ -127,8 +87,8 @@ def update_my_data(my_data): # Update model save choices due to changes for LoRA and TI training if ( - (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) - and my_data.get('save_model_as') not in ['safetensors', 'ckpt'] + (my_data.get('LoRA_type') or my_data.get('num_vectors_per_token')) + and my_data.get('save_model_as') not in ['safetensors', 'ckpt'] ): message = ( 'Updating save_model_as to safetensors because the current value in the config file is no longer applicable to {}' @@ -142,6 +102,11 @@ def update_my_data(my_data): return my_data +def get_dir_and_file(file_path): + dir_path, file_name = os.path.split(file_path) + return (dir_path, file_name) + + # def has_ext_files(directory, extension): # # Iterate through all the files in the directory # for file in os.listdir(directory): @@ -152,38 +117,67 @@ def update_my_data(my_data): # return False -def get_file_path(file_path, initial_dir=None, initial_file=None, filedialog_type="lora"): - file_extension = os.path.splitext(file_path)[-1].lower() +def get_file_path( + file_path='', default_extension='.json', extension_name='Config files' +): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') - # Find the appropriate filedialog_type based on the file extension - for key, extensions in CommonUtilities.file_filters.items(): - if file_extension in extensions: - filedialog_type = key - break + initial_dir, initial_file = get_dir_and_file(file_path) - current_file_path = file_path + # Create a hidden Tkinter root window + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() - initial_dir, initial_file = os.path.split(file_path) - result = open_file_dialog(initial_dir=initial_dir, initial_file=initial_file, file_types=filedialog_type) - file_path, canceled = result[:2] - return file_path, canceled + # Show the open file dialog and get the selected file path + file_path = filedialog.askopenfilename( + filetypes=( + (extension_name, f'*{default_extension}'), + ('All files', '*.*'), + ), + defaultextension=default_extension, + initialfile=initial_file, + initialdir=initial_dir, + ) + + # Destroy the hidden root window + root.destroy() + + # If no file is selected, use the current file path + if not file_path: + file_path = current_file_path + current_file_path = file_path + # print(f'current file path: {current_file_path}') + + return file_path def get_any_file_path(file_path=''): - current_file_path = file_path - # print(f'current file path: {current_file_path}') + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') - initial_dir, initial_file = os.path.split(file_path) - file_path = open_file_dialog(initial_dir, initial_file, "all") + initial_dir, initial_file = get_dir_and_file(file_path) - if file_path == '': - file_path = current_file_path + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + file_path = filedialog.askopenfilename( + initialdir=initial_dir, + initialfile=initial_file, + ) + root.destroy() + + if file_path == '': + file_path = current_file_path return file_path def remove_doublequote(file_path): - if file_path is not None: + if file_path != None: file_path = file_path.replace('"', '') return file_path @@ -202,37 +196,62 @@ def remove_doublequote(file_path): # ) -def get_folder_path(folder_path='', filedialog_type="directory"): - current_folder_path = folder_path +def get_folder_path(folder_path=''): + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_folder_path = folder_path - initial_dir, initial_file = os.path.split(folder_path) - file_path = open_file_dialog(initial_dir, initial_file, filedialog_type) + initial_dir, initial_file = get_dir_and_file(folder_path) - if folder_path == '': - folder_path = current_folder_path + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + folder_path = filedialog.askdirectory(initialdir=initial_dir) + root.destroy() + + if folder_path == '': + folder_path = current_folder_path return folder_path def get_saveasfile_path( - file_path='', filedialog_type="json" + file_path='', defaultextension='.json', extension_name='Config files' ): - current_file_path = file_path + if not any(var in os.environ for var in FILE_ENV_EXCLUSION): + current_file_path = file_path + # print(f'current file path: {current_file_path}') - initial_dir, initial_file = os.path.split(file_path) - save_file_path = save_file_dialog(initial_dir, initial_file, filedialog_type) + initial_dir, initial_file = get_dir_and_file(file_path) - if save_file_path is None: - file_path = current_file_path - else: - print(save_file_path.name) - file_path = save_file_path.name + root = Tk() + root.wm_attributes('-topmost', 1) + root.withdraw() + save_file_path = filedialog.asksaveasfile( + filetypes=( + (f'{extension_name}', f'{defaultextension}'), + ('All files', '*'), + ), + defaultextension=defaultextension, + initialdir=initial_dir, + initialfile=initial_file, + ) + root.destroy() + + # print(save_file_path) + + if save_file_path == None: + file_path = current_file_path + else: + print(save_file_path.name) + file_path = save_file_path.name + + # print(file_path) return file_path def get_saveasfilename_path( - file_path='', extensions='*', extension_name='Config files' + file_path='', extensions='*', extension_name='Config files' ): if not any(var in os.environ for var in FILE_ENV_EXCLUSION): current_file_path = file_path @@ -261,10 +280,10 @@ def get_saveasfilename_path( def add_pre_postfix( - folder: str = '', - prefix: str = '', - postfix: str = '', - caption_file_ext: str = '.caption', + folder: str = '', + prefix: str = '', + postfix: str = '', + caption_file_ext: str = '.caption', ) -> None: """ Add prefix and/or postfix to the content of caption files within a folder. @@ -324,10 +343,10 @@ def has_ext_files(folder_path: str, file_extension: str) -> bool: def find_replace( - folder_path: str = '', - caption_file_ext: str = '.caption', - search_text: str = '', - replace_text: str = '', + folder_path: str = '', + caption_file_ext: str = '.caption', + search_text: str = '', + replace_text: str = '', ) -> None: """ Find and replace text in caption files within a folder. @@ -341,7 +360,7 @@ def find_replace( print('Running caption find/replace') if not has_ext_files(folder_path, caption_file_ext): - show_message_box( + msgbox( f'No files with extension {caption_file_ext} were found in {folder_path}...' ) return @@ -355,7 +374,7 @@ def find_replace( for caption_file in caption_files: with open( - os.path.join(folder_path, caption_file), 'r', errors='ignore' + os.path.join(folder_path, caption_file), 'r', errors='ignore' ) as f: content = f.read() @@ -367,7 +386,7 @@ def find_replace( def color_aug_changed(color_aug): if color_aug: - show_message_box( + msgbox( 'Disabling "Cache latent" because "Color augmentation" has been selected...' ) return gr.Checkbox.update(value=False, interactive=False) @@ -408,7 +427,7 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name): def set_pretrained_model_name_or_path_input( - model_list, pretrained_model_name_or_path, v2, v_parameterization + model_list, pretrained_model_name_or_path, v2, v_parameterization ): # check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list if str(model_list) in V2_BASE_MODELS: @@ -433,9 +452,9 @@ def set_pretrained_model_name_or_path_input( if model_list == 'custom': if ( - str(pretrained_model_name_or_path) in V1_MODELS - or str(pretrained_model_name_or_path) in V2_BASE_MODELS - or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS + str(pretrained_model_name_or_path) in V1_MODELS + or str(pretrained_model_name_or_path) in V2_BASE_MODELS + or str(pretrained_model_name_or_path) in V_PARAMETERIZATION_MODELS ): pretrained_model_name_or_path = '' v2 = False @@ -462,11 +481,12 @@ def set_v2_checkbox(model_list, v2, v_parameterization): def set_model_list( - model_list, - pretrained_model_name_or_path, - v2, - v_parameterization, + model_list, + pretrained_model_name_or_path, + v2, + v_parameterization, ): + if not pretrained_model_name_or_path in ALL_PRESET_MODELS: model_list = 'custom' else: @@ -509,7 +529,7 @@ def gradio_config(): def get_pretrained_model_name_or_path_file( - model_list, pretrained_model_name_or_path + model_list, pretrained_model_name_or_path ): pretrained_model_name_or_path = get_any_file_path( pretrained_model_name_or_path @@ -517,13 +537,13 @@ def get_pretrained_model_name_or_path_file( set_model_list(model_list, pretrained_model_name_or_path) -def gradio_source_model(save_model_as_choices=[ - 'same as source model', - 'ckpt', - 'diffusers', - 'diffusers_safetensors', - 'safetensors', -]): +def gradio_source_model(save_model_as_choices = [ + 'same as source model', + 'ckpt', + 'diffusers', + 'diffusers_safetensors', + 'safetensors', + ]): with gr.Tab('Source model'): # Define the input elements with gr.Row(): @@ -628,9 +648,9 @@ def gradio_source_model(save_model_as_choices=[ def gradio_training( - learning_rate_value='1e-6', - lr_scheduler_value='constant', - lr_warmup_value='0', + learning_rate_value='1e-6', + lr_scheduler_value='constant', + lr_warmup_value='0', ): with gr.Row(): train_batch_size = gr.Slider( @@ -820,7 +840,7 @@ def gradio_advanced_training(): xformers = gr.Checkbox(label='Use xformers', value=True) color_aug = gr.Checkbox(label='Color augmentation', value=False) flip_aug = gr.Checkbox(label='Flip augmentation', value=False) - min_snr_gamma = gr.Slider(label='Min SNR gamma', value=0, minimum=0, maximum=20, step=1) + min_snr_gamma = gr.Slider(label='Min SNR gamma', value = 0, minimum=0, maximum=20, step=1) with gr.Row(): bucket_no_upscale = gr.Checkbox( label="Don't upscale bucket resolution", value=True diff --git a/library/common_utilities.py b/library/common_utilities.py deleted file mode 100644 index 737828b..0000000 --- a/library/common_utilities.py +++ /dev/null @@ -1,24 +0,0 @@ -class CommonUtilities: - file_filters = { - "all": [("All files", "*.*")], - "video": [("Video files", "*.mp4;*.avi;*.mkv;*.mov;*.flv;*.wmv")], - "images": [("Image files", "*.jpg;*.jpeg;*.png;*.bmp;*.gif;*.tiff")], - "json": [("JSON files", "*.json")], - "lora": [("LoRa files", "*.ckpt;*.pt;*.safetensors")], - "directory": [], - } - - def is_valid_config(self, data): - # Check if the data is a dictionary - if not isinstance(data, dict): - return False - - # Add checks for expected keys and valid values - # For example, check if 'use_8bit_adam' is a boolean - if "use_8bit_adam" in data and not isinstance(data["use_8bit_adam"], bool): - return False - - # Add more checks for other keys as needed - - # If all checks pass, return True - return True diff --git a/library/config_ml_util.py b/library/config_util.py similarity index 100% rename from library/config_ml_util.py rename to library/config_util.py index af35896..97bbb4a 100644 --- a/library/config_ml_util.py +++ b/library/config_util.py @@ -1,13 +1,13 @@ import argparse -import functools -import json -import random from dataclasses import ( asdict, dataclass, ) -from pathlib import Path +import functools +import random from textwrap import dedent, indent +import json +from pathlib import Path # from toolz import curry from typing import ( List, @@ -19,7 +19,6 @@ from typing import ( import toml import voluptuous -from transformers import CLIPTokenizer from voluptuous import ( Any, ExactSequence, @@ -28,6 +27,7 @@ from voluptuous import ( Required, Schema, ) +from transformers import CLIPTokenizer from . import train_util from .train_util import ( diff --git a/library/convert_model_gui.py b/library/convert_model_gui.py index 580833d..aaa39b8 100644 --- a/library/convert_model_gui.py +++ b/library/convert_model_gui.py @@ -1,29 +1,28 @@ +import gradio as gr +from easygui import msgbox +import subprocess import os import shutil -import subprocess - -import gradio as gr - -from .common_gui_functions import get_folder_path, get_file_path +from .common_gui import get_folder_path, get_file_path folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def convert_model( - source_model_input, - source_model_type, - target_model_folder_input, - target_model_name_input, - target_model_type, - target_save_precision_type, + source_model_input, + source_model_type, + target_model_folder_input, + target_model_name_input, + target_model_type, + target_save_precision_type, ): # Check for caption_text_input if source_model_type == '': - show_message_box('Invalid source model type') + msgbox('Invalid source model type') return # Check if source model exist @@ -32,14 +31,14 @@ def convert_model( elif os.path.isdir(source_model_input): print('The provided model is a folder') else: - show_message_box('The provided source model is neither a file nor a folder') + msgbox('The provided source model is neither a file nor a folder') return # Check if source model exist if os.path.isdir(target_model_folder_input): print('The provided model folder exist') else: - show_message_box('The provided target folder does not exist') + msgbox('The provided target folder does not exist') return run_cmd = f'{PYTHON} "tools/convert_diffusers20_original_sd.py"' @@ -61,8 +60,8 @@ def convert_model( run_cmd += f' --{target_save_precision_type}' if ( - target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): run_cmd += f' --reference_model="{source_model_type}"' @@ -72,8 +71,8 @@ def convert_model( run_cmd += f' "{source_model_input}"' if ( - target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): target_model_path = os.path.join( target_model_folder_input, target_model_name_input @@ -95,8 +94,8 @@ def convert_model( subprocess.run(run_cmd) if ( - not target_model_type == 'diffuser' - or target_model_type == 'diffuser_safetensors' + not target_model_type == 'diffuser' + or target_model_type == 'diffuser_safetensors' ): v2_models = [ @@ -180,7 +179,7 @@ def gradio_convert_model_tab(): document_symbol, elem_id='open_folder_small' ) button_source_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[source_model_input], outputs=source_model_input, show_progress=False, diff --git a/library/dataset_balancing_gui.py b/library/dataset_balancing_gui.py index f21319b..2e6bc98 100644 --- a/library/dataset_balancing_gui.py +++ b/library/dataset_balancing_gui.py @@ -1,11 +1,8 @@ import os import re - import gradio as gr -from easygui import boolbox - -from .common_gui_functions import get_folder_path - +from easygui import msgbox, boolbox +from .common_gui import get_folder_path # def select_folder(): # # Open a file dialog to select a directory @@ -19,14 +16,14 @@ def dataset_balancing(concept_repeats, folder, insecure): if not concept_repeats > 0: # Display an error message if the total number of repeats is not a valid integer - show_message_box('Please enter a valid integer for the total number of repeats.') + msgbox('Please enter a valid integer for the total number of repeats.') return concept_repeats = int(concept_repeats) # Check if folder exist if folder == '' or not os.path.isdir(folder): - show_message_box('Please enter a valid folder for balancing.') + msgbox('Please enter a valid folder for balancing.') return pattern = re.compile(r'^\d+_.+$') @@ -88,7 +85,7 @@ def dataset_balancing(concept_repeats, folder, insecure): f'Skipping folder {subdir} because it does not match kohya_ss expected syntax...' ) - show_message_box('Dataset balancing completed...') + msgbox('Dataset balancing completed...') def warning(insecure): diff --git a/library/dreambooth_folder_creation_gui.py b/library/dreambooth_folder_creation_gui.py index e554930..b5d5ff4 100644 --- a/library/dreambooth_folder_creation_gui.py +++ b/library/dreambooth_folder_creation_gui.py @@ -1,9 +1,8 @@ -import os -import shutil - import gradio as gr - -from .common_gui_functions import get_folder_path +from easygui import diropenbox, msgbox +from .common_gui import get_folder_path +import shutil +import os def copy_info_to_Folders_tab(training_folder): @@ -40,12 +39,12 @@ def dreambooth_folder_preparation( # Check for instance prompt if util_instance_prompt_input == '': - show_message_box('Instance prompt missing...') + msgbox('Instance prompt missing...') return # Check for class prompt if util_class_prompt_input == '': - show_message_box('Class prompt missing...') + msgbox('Class prompt missing...') return # Create the training_dir path diff --git a/library/extract_lora_gui.py b/library/extract_lora_gui.py index b54e4ff..53292d3 100644 --- a/library/extract_lora_gui.py +++ b/library/extract_lora_gui.py @@ -1,49 +1,50 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import ( - get_file_path, get_saveasfile_path, +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, ) folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 -document_symbol = '\U0001F4C4' # 📄 +document_symbol = '\U0001F4C4' # 📄 PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' def extract_lora( - model_tuned, - model_org, - save_to, - save_precision, - dim, - v2, - conv_dim, - device, + model_tuned, + model_org, + save_to, + save_precision, + dim, + v2, + conv_dim, + device, ): # Check for caption_text_input if model_tuned == '': - show_message_box('Invalid finetuned model file') + msgbox('Invalid finetuned model file') return if model_org == '': - show_message_box('Invalid base model file') + msgbox('Invalid base model file') return # Check if source model exist if not os.path.isfile(model_tuned): - show_message_box('The provided finetuned model is not a file') + msgbox('The provided finetuned model is not a file') return if not os.path.isfile(model_org): - show_message_box('The provided base model is not a file') + msgbox('The provided base model is not a file') return run_cmd = ( - f'{PYTHON} "{os.path.join("networks", "extract_lora_from_models.py")}"' + f'{PYTHON} "{os.path.join("networks","extract_lora_from_models.py")}"' ) run_cmd += f' --save_precision {save_precision}' run_cmd += f' --save_to "{save_to}"' @@ -90,7 +91,7 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_tuned_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[model_tuned, model_ext, model_ext_name], outputs=model_tuned, show_progress=False, @@ -105,8 +106,7 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_model_org_file.click( - lambda input1, input2, input3, *args, **kwargs: - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[model_org, model_ext, model_ext_name], outputs=model_org, show_progress=False, @@ -121,7 +121,7 @@ def gradio_extract_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfile_path, + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/extract_lycoris_locon_gui.py b/library/extract_lycoris_locon_gui.py index e8a620b..13575bb 100644 --- a/library/extract_lycoris_locon_gui.py +++ b/library/extract_lycoris_locon_gui.py @@ -1,10 +1,11 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import ( - get_file_path, get_saveasfile_path, +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, ) folder_symbol = '\U0001f4c2' # 📂 @@ -35,20 +36,20 @@ def extract_lycoris_locon( ): # Check for caption_text_input if db_model == '': - show_message_box('Invalid finetuned model file') + msgbox('Invalid finetuned model file') return if base_model == '': - show_message_box('Invalid base model file') + msgbox('Invalid base model file') return # Check if source model exist if not os.path.isfile(db_model): - show_message_box('The provided finetuned model is not a file') + msgbox('The provided finetuned model is not a file') return if not os.path.isfile(base_model): - show_message_box('The provided base model is not a file') + msgbox('The provided base model is not a file') return run_cmd = f'{PYTHON} "{os.path.join("tools","lycoris_locon_extract.py")}"' @@ -136,8 +137,7 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_db_model_file.click( - lambda input1, input2, input3, *args, **kwargs: - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[db_model, model_ext, model_ext_name], outputs=db_model, show_progress=False, @@ -152,7 +152,7 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_base_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[base_model, model_ext, model_ext_name], outputs=base_model, show_progress=False, @@ -167,7 +167,7 @@ def gradio_extract_lycoris_locon_tab(): folder_symbol, elem_id='open_folder_small' ) button_output_name.click( - get_saveasfile_path, + get_saveasfilename_path, inputs=[output_name, lora_ext, lora_ext_name], outputs=output_name, show_progress=False, diff --git a/library/git_caption_gui.py b/library/git_caption_gui.py index a4cc6d0..9aaf3d9 100644 --- a/library/git_caption_gui.py +++ b/library/git_caption_gui.py @@ -1,9 +1,8 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import get_folder_path, add_pre_postfix +from easygui import msgbox +import subprocess +import os +from .common_gui import get_folder_path, add_pre_postfix PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' @@ -20,11 +19,11 @@ def caption_images( ): # Check for images_dir_input if train_data_dir == '': - show_message_box('Image folder is missing...') + msgbox('Image folder is missing...') return if caption_ext == '': - show_message_box('Please provide an extension for the caption files.') + msgbox('Please provide an extension for the caption files.') return print(f'GIT captioning files in {train_data_dir}...') diff --git a/library/gui_subprocesses.py b/library/gui_subprocesses.py deleted file mode 100644 index 2e45b8e..0000000 --- a/library/gui_subprocesses.py +++ /dev/null @@ -1,86 +0,0 @@ -import os -import pathlib -import sys -import tkinter as tk -from tkinter import filedialog, messagebox - -from library.common_gui_functions import tk_context -from library.common_utilities import CommonUtilities - - -class TkGui: - def __init__(self): - self.file_types = None - - def open_file_dialog(self, initial_dir=None, initial_file=None, file_types="all"): - with tk_context(): - self.file_types = file_types - if self.file_types in CommonUtilities.file_filters: - filters = CommonUtilities.file_filters[self.file_types] - else: - filters = CommonUtilities.file_filters["all"] - - if self.file_types == "directory": - result = filedialog.askdirectory(initialdir=initial_dir) - else: - result = filedialog.askopenfilename(initialdir=initial_dir, initialfile=initial_file, filetypes=filters) - - # Return a tuple (file_path, canceled) - # file_path: the selected file path or an empty string if no file is selected - # canceled: True if the user pressed the cancel button, False otherwise - return result, result == "" - - def save_file_dialog(self, initial_dir, initial_file, file_types="all"): - self.file_types = file_types - - # Use the tk_context function with the 'with' statement - with tk_context(): - if self.file_types in CommonUtilities.file_filters: - filters = CommonUtilities.file_filters[self.file_types] - else: - filters = CommonUtilities.file_filters["all"] - - save_file_path = filedialog.asksaveasfilename(initialdir=initial_dir, initialfile=initial_file, - filetypes=filters, defaultextension=".safetensors") - - return save_file_path - - def show_message_box(_message, _title="Message", _level="info"): - with tk_context(): - message_type = { - "warning": messagebox.showwarning, - "error": messagebox.showerror, - "info": messagebox.showinfo, - "question": messagebox.askquestion, - "okcancel": messagebox.askokcancel, - "retrycancel": messagebox.askretrycancel, - "yesno": messagebox.askyesno, - "yesnocancel": messagebox.askyesnocancel - } - - if _level in message_type: - message_type[_level](_title, _message) - else: - messagebox.showinfo(_title, _message) - - -if __name__ == '__main__': - try: - mode = sys.argv[1] - - if mode == 'file_dialog': - starting_dir = sys.argv[2] if len(sys.argv) > 2 else None - starting_file = sys.argv[3] if len(sys.argv) > 3 else None - file_class = sys.argv[4] if len(sys.argv) > 4 else None # Update this to sys.argv[4] - gui = TkGui() - file_path = gui.open_file_dialog(starting_dir, starting_file, file_class) - print(file_path) # Make sure to print the result - - elif mode == 'msgbox': - message = sys.argv[2] - title = sys.argv[3] if len(sys.argv) > 3 else "" - gui = TkGui() - gui.show_message_box(message, title) - except Exception as e: - print(f"Error: {e}", file=sys.stderr) - sys.exit(1) diff --git a/library/merge_lora_gui.py b/library/merge_lora_gui.py index e2bd428..21cd16a 100644 --- a/library/merge_lora_gui.py +++ b/library/merge_lora_gui.py @@ -1,10 +1,11 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import ( - get_file_path, get_saveasfile_path, +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, ) folder_symbol = '\U0001f4c2' # 📂 @@ -24,20 +25,20 @@ def merge_lora( ): # Check for caption_text_input if lora_a_model == '': - show_message_box('Invalid model A file') + msgbox('Invalid model A file') return if lora_b_model == '': - show_message_box('Invalid model B file') + msgbox('Invalid model B file') return # Check if source model exist if not os.path.isfile(lora_a_model): - show_message_box('The provided model A is not a file') + msgbox('The provided model A is not a file') return if not os.path.isfile(lora_b_model): - show_message_box('The provided model B is not a file') + msgbox('The provided model B is not a file') return ratio_a = ratio @@ -81,7 +82,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -96,7 +97,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, @@ -121,7 +122,7 @@ def gradio_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfile_path, + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/resize_lora_gui.py b/library/resize_lora_gui.py index fc7ae12..ecf1b45 100644 --- a/library/resize_lora_gui.py +++ b/library/resize_lora_gui.py @@ -1,9 +1,8 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import get_file_path, get_saveasfile_path +from easygui import msgbox +import subprocess +import os +from .common_gui import get_saveasfilename_path, get_file_path PYTHON = 'python3' if os.name == 'posix' else './venv/Scripts/python.exe' folder_symbol = '\U0001f4c2' # 📂 @@ -24,24 +23,24 @@ def resize_lora( ): # Check for caption_text_input if model == '': - show_message_box('Invalid model file') + msgbox('Invalid model file') return # Check if source model exist if not os.path.isfile(model): - show_message_box('The provided model is not a file') + msgbox('The provided model is not a file') return if dynamic_method == 'sv_ratio': if float(dynamic_param) < 2: - show_message_box( + msgbox( f'Dynamic parameter for {dynamic_method} need to be 2 or greater...' ) return if dynamic_method == 'sv_fro' or dynamic_method == 'sv_cumulative': if float(dynamic_param) < 0 or float(dynamic_param) > 1: - show_message_box( + msgbox( f'Dynamic parameter for {dynamic_method} need to be between 0 and 1...' ) return @@ -96,7 +95,7 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[model, lora_ext, lora_ext_name], outputs=model, show_progress=False, @@ -135,7 +134,7 @@ def gradio_resize_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfile_path, + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/sampler_gui.py b/library/sampler_gui.py index c1ba146..ce95313 100644 --- a/library/sampler_gui.py +++ b/library/sampler_gui.py @@ -1,6 +1,7 @@ import tempfile import os import gradio as gr +from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/library/svd_merge_lora_gui.py b/library/svd_merge_lora_gui.py index c8a2fe6..042be2e 100644 --- a/library/svd_merge_lora_gui.py +++ b/library/svd_merge_lora_gui.py @@ -1,10 +1,11 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import ( - get_file_path, get_saveasfile_path, +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, + get_file_path, ) folder_symbol = '\U0001f4c2' # 📂 @@ -27,20 +28,20 @@ def svd_merge_lora( ): # Check for caption_text_input if lora_a_model == '': - show_message_box('Invalid model A file') + msgbox('Invalid model A file') return if lora_b_model == '': - show_message_box('Invalid model B file') + msgbox('Invalid model B file') return # Check if source model exist if not os.path.isfile(lora_a_model): - show_message_box('The provided model A is not a file') + msgbox('The provided model A is not a file') return if not os.path.isfile(lora_b_model): - show_message_box('The provided model B is not a file') + msgbox('The provided model B is not a file') return ratio_a = ratio @@ -87,7 +88,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_a_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[lora_a_model, lora_ext, lora_ext_name], outputs=lora_a_model, show_progress=False, @@ -102,7 +103,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_b_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[lora_b_model, lora_ext, lora_ext_name], outputs=lora_b_model, show_progress=False, @@ -143,7 +144,7 @@ def gradio_svd_merge_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_save_to.click( - get_saveasfile_path, + get_saveasfilename_path, inputs=[save_to, lora_ext, lora_ext_name], outputs=save_to, show_progress=False, diff --git a/library/tensorboard_gui.py b/library/tensorboard_gui.py index f2dd74f..d08a02d 100644 --- a/library/tensorboard_gui.py +++ b/library/tensorboard_gui.py @@ -1,9 +1,9 @@ import os +import gradio as gr +from easygui import msgbox import subprocess import time -import gradio as gr - tensorboard_proc = None # I know... bad but heh TENSORBOARD = 'tensorboard' if os.name == 'posix' else 'tensorboard.exe' @@ -13,7 +13,7 @@ def start_tensorboard(logging_dir): if not os.listdir(logging_dir): print('Error: log folder is empty') - show_message_box(msg='Error: log folder is empty') + msgbox(msg='Error: log folder is empty') return run_cmd = [f'{TENSORBOARD}', '--logdir', f'{logging_dir}'] diff --git a/library/verify_lora_gui.py b/library/verify_lora_gui.py index 52626dd..a7a0bf9 100644 --- a/library/verify_lora_gui.py +++ b/library/verify_lora_gui.py @@ -1,9 +1,10 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import ( +from easygui import msgbox +import subprocess +import os +from .common_gui import ( + get_saveasfilename_path, + get_any_file_path, get_file_path, ) @@ -19,12 +20,12 @@ def verify_lora( ): # verify for caption_text_input if lora_model == '': - show_message_box('Invalid model A file') + msgbox('Invalid model A file') return # verify if source model exist if not os.path.isfile(lora_model): - show_message_box('The provided model A is not a file') + msgbox('The provided model A is not a file') return run_cmd = [ @@ -68,7 +69,7 @@ def gradio_verify_lora_tab(): folder_symbol, elem_id='open_folder_small' ) button_lora_model_file.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, inputs=[lora_model, lora_ext, lora_ext_name], outputs=lora_model, show_progress=False, diff --git a/library/wd14_caption_gui.py b/library/wd14_caption_gui.py index 18103ef..1970849 100644 --- a/library/wd14_caption_gui.py +++ b/library/wd14_caption_gui.py @@ -1,9 +1,8 @@ -import os -import subprocess - import gradio as gr - -from .common_gui_functions import get_folder_path +from easygui import msgbox +import subprocess +from .common_gui import get_folder_path +import os def replace_underscore_with_space(folder_path, file_extension): @@ -21,16 +20,16 @@ def caption_images( ): # Check for caption_text_input # if caption_text_input == "": - # show_message_box("Caption text is missing...") + # msgbox("Caption text is missing...") # return # Check for images_dir_input if train_data_dir == '': - show_message_box('Image folder is missing...') + msgbox('Image folder is missing...') return if caption_extension == '': - show_message_box('Please provide an extension for the caption files.') + msgbox('Please provide an extension for the caption files.') return print(f'Captioning files in {train_data_dir}...') diff --git a/lora_gui.py b/lora_gui.py index 03b22d1..ccca947 100644 --- a/lora_gui.py +++ b/lora_gui.py @@ -3,16 +3,15 @@ # v3: Add new Utilities tab for Dreambooth folder preparation # v3.1: Adding captionning of images to utilities -import argparse +import gradio as gr +import easygui import json import math import os -import pathlib import subprocess - -import gradio as gr - -from library.common_gui_functions import ( +import pathlib +import argparse +from library.common_gui import ( get_folder_path, remove_doublequote, get_file_path, @@ -28,23 +27,24 @@ from library.common_gui_functions import ( run_cmd_training, # set_legacy_8bitadam, update_my_data, - check_if_model_exist, show_message_box, + check_if_model_exist, ) -from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.dreambooth_folder_creation_gui import ( gradio_dreambooth_folder_creation_tab, ) -from library.merge_lora_gui import gradio_merge_lora_tab -from library.resize_lora_gui import gradio_resize_lora_tab -from library.sampler_gui import sample_gradio_config, run_cmd_sample -from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) +from library.dataset_balancing_gui import gradio_dataset_balancing_tab from library.utilities import utilities_tab +from library.merge_lora_gui import gradio_merge_lora_tab +from library.svd_merge_lora_gui import gradio_svd_merge_lora_tab from library.verify_lora_gui import gradio_verify_lora_tab +from library.resize_lora_gui import gradio_resize_lora_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -359,35 +359,35 @@ def train_model( print_only_bool = True if print_only.get('label') == 'True' else False if pretrained_model_name_or_path == '': - show_message_box('Source model information is missing') + msgbox('Source model information is missing') return if train_data_dir == '': - show_message_box('Image folder path is missing') + msgbox('Image folder path is missing') return if not os.path.exists(train_data_dir): - show_message_box('Image folder does not exist') + msgbox('Image folder does not exist') return if reg_data_dir != '': if not os.path.exists(reg_data_dir): - show_message_box('Regularisation folder does not exist') + msgbox('Regularisation folder does not exist') return if output_dir == '': - show_message_box('Output folder path is missing') + msgbox('Output folder path is missing') return if int(bucket_reso_steps) < 1: - show_message_box('Bucket resolution steps need to be greater than 0') + msgbox('Bucket resolution steps need to be greater than 0') return if not os.path.exists(output_dir): os.makedirs(output_dir) if stop_text_encoder_training_pct > 0: - show_message_box( + msgbox( 'Output "stop text encoder training" is not yet supported. Ignoring' ) stop_text_encoder_training_pct = 0 @@ -402,7 +402,7 @@ def train_model( unet_lr = 0 # if (float(text_encoder_lr) == 0) and (float(unet_lr) == 0): - # show_message_box( + # msgbox( # 'At least one Learning Rate value for "Text encoder" or "Unet" need to be provided' # ) # return @@ -540,7 +540,7 @@ def train_model( run_cmd += f' --network_train_unet_only' else: if float(text_encoder_lr) == 0: - show_message_box('Please input learning rate values.') + msgbox('Please input learning rate values.') return run_cmd += f' --network_dim={network_dim}' @@ -1031,14 +1031,14 @@ def lora_tab( ] button_open_config.click( - lambda *args, **kwargs: open_configuration(), + open_configuration, inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, ) button_load_config.click( - lambda *args, **kwargs: open_configuration(), + open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list + [LoCon_row], show_progress=False, diff --git a/textual_inversion_gui.py b/textual_inversion_gui.py index f8aaa36..da5467d 100644 --- a/textual_inversion_gui.py +++ b/textual_inversion_gui.py @@ -3,16 +3,14 @@ # v3: Add new Utilities tab for Dreambooth folder preparation # v3.1: Adding captionning of images to utilities -import argparse +import gradio as gr import json import math import os -import pathlib import subprocess - -import gradio as gr - -from library.common_gui_functions import ( +import pathlib +import argparse +from library.common_gui import ( get_folder_path, remove_doublequote, get_file_path, @@ -30,16 +28,17 @@ from library.common_gui_functions import ( update_my_data, check_if_model_exist, ) -from library.dreambooth_folder_creation_gui import ( - gradio_dreambooth_folder_creation_tab, -) -from library.sampler_gui import sample_gradio_config, run_cmd_sample from library.tensorboard_gui import ( gradio_tensorboard, start_tensorboard, stop_tensorboard, ) +from library.dreambooth_folder_creation_gui import ( + gradio_dreambooth_folder_creation_tab, +) from library.utilities import utilities_tab +from library.sampler_gui import sample_gradio_config, run_cmd_sample +from easygui import msgbox folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -242,7 +241,7 @@ def open_configuration( if ask_for_file: file_path = get_file_path(file_path) - if not file_path == '' and file_path is not None: + if not file_path == '' and not file_path == None: # load variables from JSON file with open(file_path, 'r') as f: my_data = json.load(f) @@ -330,32 +329,32 @@ def train_model( min_snr_gamma, ): if pretrained_model_name_or_path == '': - show_message_box('Source model information is missing') + msgbox('Source model information is missing') return if train_data_dir == '': - show_message_box('Image folder path is missing') + msgbox('Image folder path is missing') return if not os.path.exists(train_data_dir): - show_message_box('Image folder does not exist') + msgbox('Image folder does not exist') return if reg_data_dir != '': if not os.path.exists(reg_data_dir): - show_message_box('Regularisation folder does not exist') + msgbox('Regularisation folder does not exist') return if output_dir == '': - show_message_box('Output folder path is missing') + msgbox('Output folder path is missing') return if token_string == '': - show_message_box('Token string is missing') + msgbox('Token string is missing') return if init_word == '': - show_message_box('Init word is missing') + msgbox('Init word is missing') return if not os.path.exists(output_dir): @@ -673,7 +672,7 @@ def ti_tab( ) weights_file_input = gr.Button('📂', elem_id='open_folder_small') weights_file_input.click( - lambda *args, **kwargs: get_file_path(*args), + get_file_path, outputs=weights, show_progress=False, ) @@ -899,14 +898,14 @@ def ti_tab( ] button_open_config.click( - lambda *args, **kwargs: open_configuration(), + open_configuration, inputs=[dummy_db_true, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, ) button_load_config.click( - lambda *args, **kwargs: open_configuration(), + open_configuration, inputs=[dummy_db_false, config_file_name] + settings_list, outputs=[config_file_name] + settings_list, show_progress=False, diff --git a/train_db.py b/train_db.py index 37fa38e..b3eead9 100644 --- a/train_db.py +++ b/train_db.py @@ -1,25 +1,28 @@ # DreamBooth training # XXX dropped option: fine_tune -import argparse import gc +import time +import argparse import itertools import math import os +import toml from multiprocessing import Value +from tqdm import tqdm import torch from accelerate.utils import set_seed +import diffusers from diffusers import DDPMScheduler -from tqdm import tqdm -import library.config_ml_util as config_util -import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.config_ml_util import ( +import library.config_util as config_util +from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b6f56a4..f279370 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,21 +1,24 @@ +import importlib import argparse import gc import math import os +import toml from multiprocessing import Value +from tqdm import tqdm import torch from accelerate.utils import set_seed +import diffusers from diffusers import DDPMScheduler -from tqdm import tqdm -import library.config_ml_util as config_util -import library.custom_train_functions as custom_train_functions import library.train_util as train_util -from library.config_ml_util import ( +import library.config_util as config_util +from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) +import library.custom_train_functions as custom_train_functions from library.custom_train_functions import apply_snr_weight imagenet_templates_small = [