Fix basic captioning logic

This commit is contained in:
bmaltais 2023-03-15 19:31:52 -04:00
parent 7a94c523f5
commit baf009d2b1
5 changed files with 254 additions and 129 deletions

View File

@ -189,6 +189,8 @@ This will store your a backup file with your current locally installed pip packa
## Change History
* 2023/03/16 (v21.2.5):
- Fix basic captioning logic
* 2023/03/12 (v21.2.4):
- Fix issue with kohya locon not training the convolution layers
- Update LyCORIS module version

View File

@ -6,35 +6,33 @@ import os
def caption_images(
caption_text_input,
images_dir_input,
overwrite_input,
caption_file_ext,
caption_text,
images_dir,
overwrite,
caption_ext,
prefix,
postfix,
find,
replace,
find_text,
replace_text,
):
# Check for images_dir_input
if images_dir_input == '':
# Check for images_dir
if not images_dir:
msgbox('Image folder is missing...')
return
if caption_file_ext == '':
if not caption_ext:
msgbox('Please provide an extension for the caption files.')
return
if not caption_text_input == '':
print(
f'Captioning files in {images_dir_input} with {caption_text_input}...'
)
if caption_text:
print(f'Captioning files in {images_dir} with {caption_text}...')
run_cmd = f'python "tools/caption.py"'
run_cmd += f' --caption_text="{caption_text_input}"'
if overwrite_input:
run_cmd += f' --caption_text="{caption_text}"'
if overwrite:
run_cmd += f' --overwrite'
if caption_file_ext != '':
run_cmd += f' --caption_file_ext="{caption_file_ext}"'
run_cmd += f' "{images_dir_input}"'
if caption_ext:
run_cmd += f' --caption_file_ext="{caption_ext}"'
run_cmd += f' "{images_dir}"'
print(run_cmd)
@ -44,24 +42,24 @@ def caption_images(
else:
subprocess.run(run_cmd)
if overwrite_input:
if not prefix == '' or not postfix == '':
if overwrite:
if prefix or postfix:
# Add prefix and postfix
add_pre_postfix(
folder=images_dir_input,
caption_file_ext=caption_file_ext,
folder=images_dir,
caption_file_ext=caption_ext,
prefix=prefix,
postfix=postfix,
)
if not find == '':
if find_text:
find_replace(
folder=images_dir_input,
caption_file_ext=caption_file_ext,
find=find,
replace=replace,
folder=images_dir,
caption_file_ext=caption_ext,
find=find_text,
replace=replace_text,
)
else:
if not prefix == '' or not postfix == '':
if prefix or postfix:
msgbox(
'Could not modify caption files with requested change because the "Overwrite existing captions in folder" option is not selected...'
)
@ -69,37 +67,31 @@ def caption_images(
print('...captioning done')
###
# Gradio UI
###
def gradio_basic_caption_gui_tab():
with gr.Tab('Basic Captioning'):
gr.Markdown(
'This utility will allow the creation of simple caption files for each images in a folder.'
'This utility will allow the creation of simple caption files for each image in a folder.'
)
with gr.Row():
images_dir_input = gr.Textbox(
images_dir = gr.Textbox(
label='Image folder to caption',
placeholder='Directory containing the images to caption',
interactive=True,
)
button_images_dir_input = gr.Button(
'📂', elem_id='open_folder_small'
)
button_images_dir_input.click(
folder_button = gr.Button('📂', elem_id='open_folder_small')
folder_button.click(
get_folder_path,
outputs=images_dir_input,
outputs=images_dir,
show_progress=False,
)
caption_file_ext = gr.Textbox(
caption_ext = gr.Textbox(
label='Caption file extension',
placeholder='Extention for caption file. eg: .caption, .txt',
placeholder='Extension for caption file. eg: .caption, .txt',
value='.txt',
interactive=True,
)
overwrite_input = gr.Checkbox(
overwrite = gr.Checkbox(
label='Overwrite existing captions in folder',
interactive=True,
value=False,
@ -110,7 +102,7 @@ def gradio_basic_caption_gui_tab():
placeholder='(Optional)',
interactive=True,
)
caption_text_input = gr.Textbox(
caption_text = gr.Textbox(
label='Caption text',
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
interactive=True,
@ -121,29 +113,28 @@ def gradio_basic_caption_gui_tab():
interactive=True,
)
with gr.Row():
find = gr.Textbox(
find_text = gr.Textbox(
label='Find text',
placeholder='Eg: , by some artist. Leave empty if you just want to add pre or postfix',
interactive=True,
)
replace = gr.Textbox(
replace_text = gr.Textbox(
label='Replacement text',
placeholder='Eg: , by some artist. Leave empty if you just want to replace with nothing',
interactive=True,
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
caption_text,
images_dir,
overwrite,
caption_ext,
prefix,
postfix,
find_text,
replace_text,
],
show_progress=False,
)
caption_button = gr.Button('Caption images')
caption_button.click(
caption_images,
inputs=[
caption_text_input,
images_dir_input,
overwrite_input,
caption_file_ext,
prefix,
postfix,
find,
replace,
],
show_progress=False,
)

View File

@ -32,77 +32,112 @@ ALL_PRESET_MODELS = V2_BASE_MODELS + V_PARAMETERIZATION_MODELS + V1_MODELS
def update_my_data(my_data):
    """
    Upgrade a loaded configuration dict, in place, to the current key layout.

    Args:
        my_data (dict): Configuration values as loaded from a saved config.

    Returns:
        dict: The same ``my_data`` object after the adjustments below.

    Raises:
        ValueError: If 'epoch' or 'save_every_n_epochs' holds a non-numeric,
            non-blank string.
    """
    # Update optimizer based on the legacy use_8bit_adam flag; configs that
    # carry neither the flag nor an optimizer fall back to AdamW.
    if my_data.get('use_8bit_adam', False):
        my_data['optimizer'] = 'AdamW8bit'
    elif 'optimizer' not in my_data:
        my_data['optimizer'] = 'AdamW'

    # Update model_list to custom if empty or pretrained_model_name_or_path
    # is not a preset model
    model_list = my_data.get('model_list', [])
    pretrained_model_name_or_path = my_data.get(
        'pretrained_model_name_or_path', ''
    )
    if (
        not model_list
        or pretrained_model_name_or_path not in ALL_PRESET_MODELS
    ):
        my_data['model_list'] = 'custom'

    # Convert epoch and save_every_n_epochs values to int if they are strings.
    # Only strings are touched: a numeric 0 is a real value and must not be
    # rewritten to the -1 "unset" marker.
    for key in ('epoch', 'save_every_n_epochs'):
        value = my_data.get(key, -1)
        if isinstance(value, str):
            my_data[key] = int(value) if value.strip() else -1

    # Update LoRA_type if it is set to LoCon
    if my_data.get('LoRA_type', 'Standard') == 'LoCon':
        my_data['LoRA_type'] = 'LyCORIS/LoCon'

    return my_data
# def update_my_data(my_data):
# if my_data.get('use_8bit_adam', False) == True:
# my_data['optimizer'] = 'AdamW8bit'
# # my_data['use_8bit_adam'] = False
# if (
# my_data.get('optimizer', 'missing') == 'missing'
# and my_data.get('use_8bit_adam', False) == False
# ):
# my_data['optimizer'] = 'AdamW'
# if my_data.get('model_list', 'custom') == []:
# print('Old config with empty model list. Setting to custom...')
# my_data['model_list'] = 'custom'
# # If Pretrained model name or path is not one of the preset models then set the preset_model to custom
# if not my_data.get('pretrained_model_name_or_path', '') in ALL_PRESET_MODELS:
# my_data['model_list'] = 'custom'
# # Fix old config files that contain epoch as str instead of int
# for key in ['epoch', 'save_every_n_epochs']:
# value = my_data.get(key, -1)
# if type(value) == str:
# if value != '':
# my_data[key] = int(value)
# else:
# my_data[key] = -1
# if my_data.get('LoRA_type', 'Standard') == 'LoCon':
# my_data['LoRA_type'] = 'LyCORIS/LoCon'
# return my_data
def get_dir_and_file(file_path):
    """
    Split a path into its directory part and its final component.

    Args:
        file_path (str): Path to split.

    Returns:
        tuple: ``(dir_path, file_name)`` as produced by ``os.path.split``.
    """
    return os.path.split(file_path)
def has_ext_files(directory, extension):
    """
    Report whether any entry name in ``directory`` ends with ``extension``.

    Args:
        directory (str): Folder to scan (not recursive).
        extension (str): Suffix to look for, e.g. '.txt'.

    Returns:
        bool: True if at least one matching name exists; False if none do or
            if ``directory`` is not an existing directory.
    """
    # A missing or invalid folder simply has no matching files; answering
    # False lets callers show their own "nothing found" message instead of
    # crashing inside os.listdir.
    if not os.path.isdir(directory):
        return False
    return any(name.endswith(extension) for name in os.listdir(directory))
# def has_ext_files(directory, extension):
# # Iterate through all the files in the directory
# for file in os.listdir(directory):
# # If the file name ends with extension, return True
# if file.endswith(extension):
# return True
# # If no extension files were found, return False
# return False
def get_file_path(
    file_path='',
    default_extension='.json',
    extension_name='Config files',
    *,
    defaultextension=None,
):
    """
    Ask the user to pick a file with a Tk "open file" dialog.

    Args:
        file_path (str): Current path; its directory and file name seed the
            dialog, and it is returned unchanged if nothing is selected.
        default_extension (str): Extension used for the first file-type
            filter and passed to the dialog as its default extension.
        extension_name (str): Label shown for the first file-type filter.
        defaultextension (str, optional): Older spelling of
            ``default_extension``; when given it takes precedence so callers
            using either keyword keep working.

    Returns:
        str: The selected path, or the original ``file_path`` if the dialog
            returned nothing.
    """
    if defaultextension is not None:
        default_extension = defaultextension

    current_file_path = file_path
    initial_dir, initial_file = get_dir_and_file(file_path)

    # Create a hidden Tkinter root window
    root = Tk()
    try:
        root.wm_attributes('-topmost', 1)
        root.withdraw()

        # Show the open file dialog and get the selected file path
        file_path = filedialog.askopenfilename(
            filetypes=(
                (extension_name, f'*{default_extension}'),
                ('All files', '*.*'),
            ),
            defaultextension=default_extension,
            initialfile=initial_file,
            initialdir=initial_dir,
        )
    finally:
        # Destroy the hidden root window even if the dialog raised, so no
        # orphaned Tk instance is left behind.
        root.destroy()

    # If no file is selected, use the current file path
    if not file_path:
        file_path = current_file_path

    return file_path
@ -230,52 +265,146 @@ def get_saveasfilename_path(
def add_pre_postfix(
    folder: str = '',
    prefix: str = '',
    postfix: str = '',
    caption_file_ext: str = '.caption'
) -> None:
    """
    Add prefix and/or postfix to the caption file of every image in a folder.

    For each file in ``folder`` whose name ends with .jpg, .jpeg, .png or
    .webp (case-insensitive), the caption file sharing its base name is
    rewritten as ``prefix content postfix``. A caption file that does not
    exist yet is created from the prefix and/or postfix alone.

    Args:
        folder (str): Path to the folder containing the images.
        prefix (str, optional): Text placed before the existing caption.
        postfix (str, optional): Text placed after the existing caption.
        caption_file_ext (str, optional): Extension of the caption files.
    """
    if not prefix and not postfix:
        return

    image_extensions = ('.jpg', '.jpeg', '.png', '.webp')
    image_files = [
        f for f in os.listdir(folder) if f.lower().endswith(image_extensions)
    ]

    for image_file in image_files:
        caption_file_name = os.path.splitext(image_file)[0] + caption_file_ext
        caption_file_path = os.path.join(folder, caption_file_name)

        content = ''
        if os.path.exists(caption_file_path):
            with open(caption_file_path, 'r') as f:
                content = f.read().rstrip()

        # Join only the non-empty pieces so an empty caption or a missing
        # prefix/postfix never produces a dangling or doubled space.
        pieces = [piece for piece in (prefix, content, postfix) if piece]

        # Re-open in 'w' mode so the file is truncated: writing over an 'r+'
        # handle leaves stale trailing bytes whenever the new text is shorter
        # than the old file (e.g. after trailing newlines were stripped).
        with open(caption_file_path, 'w') as f:
            f.write(' '.join(pieces))
# def add_pre_postfix(
# folder='', prefix='', postfix='', caption_file_ext='.caption'
# ):
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if prefix == '' and postfix == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# if not prefix == '':
# prefix = f'{prefix} '
# if not postfix == '':
# postfix = f' {postfix}'
# for file in files:
# with open(os.path.join(folder, file), 'r+') as f:
# content = f.read()
# content = content.rstrip()
# f.seek(0, 0)
# f.write(f'{prefix} {content} {postfix}')
# f.close()
def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
def has_ext_files(folder_path: str, file_extension: str) -> bool:
    """
    Check if there are any files with the specified extension in the given folder.

    Args:
        folder_path (str): Path to the folder containing files.
        file_extension (str): Extension of the files to look for.

    Returns:
        bool: True if files with the specified extension are found, False
            otherwise (including when folder_path is not an existing
            directory).
    """
    # os.listdir raises on a missing or non-directory path; treat that case
    # as "no matching files" so callers can report it normally.
    if not os.path.isdir(folder_path):
        return False
    return any(
        entry.endswith(file_extension) for entry in os.listdir(folder_path)
    )
def find_replace(
    folder_path: str = '',
    caption_file_ext: str = '.caption',
    search_text: str = '',
    replace_text: str = '',
    *,
    folder=None,
    find=None,
    replace=None,
) -> None:
    """
    Find and replace text in caption files within a folder.

    Args:
        folder_path (str, optional): Path to the folder containing caption files.
        caption_file_ext (str, optional): Extension of the caption files.
        search_text (str, optional): Text to search for in the caption files.
        replace_text (str, optional): Text to replace the search text with.
        folder (str, optional): Earlier keyword name for ``folder_path``.
        find (str, optional): Earlier keyword name for ``search_text``.
        replace (str, optional): Earlier keyword name for ``replace_text``.
    """
    # Callers such as caption_images still pass folder=/find=/replace=;
    # honour those names so they do not fail with a TypeError.
    if folder is not None:
        folder_path = folder
    if find is not None:
        search_text = find
    if replace is not None:
        replace_text = replace

    print('Running caption find/replace')

    # List the folder once; a missing folder is reported the same way as a
    # folder without caption files.
    caption_files = []
    if os.path.isdir(folder_path):
        caption_files = [
            f for f in os.listdir(folder_path) if f.endswith(caption_file_ext)
        ]

    if not caption_files:
        msgbox(
            f'No files with extension {caption_file_ext} were found in {folder_path}...'
        )
        return

    if not search_text:
        return

    for caption_file in caption_files:
        caption_path = os.path.join(folder_path, caption_file)
        # Undecodable bytes are dropped on read rather than aborting the run.
        with open(caption_path, 'r', errors='ignore') as f:
            content = f.read()
        with open(caption_path, 'w') as f:
            f.write(content.replace(search_text, replace_text))
# def find_replace(folder='', caption_file_ext='.caption', find='', replace=''):
# print('Running caption find/replace')
# if not has_ext_files(folder, caption_file_ext):
# msgbox(
# f'No files with extension {caption_file_ext} were found in {folder}...'
# )
# return
# if find == '':
# return
# files = [f for f in os.listdir(folder) if f.endswith(caption_file_ext)]
# for file in files:
# with open(os.path.join(folder, file), 'r', errors='ignore') as f:
# content = f.read()
# f.close
# content = content.replace(find, replace)
# with open(os.path.join(folder, file), 'w') as f:
# f.write(content)
# f.close()
def color_aug_changed(color_aug):

View File

@ -417,13 +417,16 @@ def train_model(
or f.endswith('.webp')
]
)
print(f'Folder {folder}: {num_images} images found')
# Calculate the total number of steps for this folder
steps = repeats * num_images
total_steps += steps
# Print the result
print(f'Folder {folder}: {steps} steps')
total_steps += steps
# calculate max_train_steps
max_train_steps = int(

View File

@ -116,7 +116,7 @@ def main():
linear_mode_param, conv_mode_param,
args.device,
args.use_sparse_bias, args.sparsity,
# not args.disable_small_conv
not args.disable_cp
)
if args.safetensors: