This commit is contained in:
bmaltais 2023-03-01 19:24:11 -05:00
parent 5498539fda
commit 182080bb78
10 changed files with 4236 additions and 2730 deletions

View File

@ -95,9 +95,11 @@ def save_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -194,9 +196,11 @@ def open_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -272,9 +276,11 @@ def train_model(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
optimizer_args,
noise_offset,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -566,7 +572,8 @@ def dreambooth_tab(
seed,
caption_extension,
cache_latents,
optimizer,optimizer_args,
optimizer,
optimizer_args,
) = gradio_training(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
@ -624,7 +631,9 @@ def dreambooth_tab(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
noise_offset,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -648,15 +657,15 @@ def dreambooth_tab(
)
button_run = gr.Button('Train model', variant='primary')
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
button_start_tensorboard.click(
start_tensorboard,
inputs=logging_dir,
)
button_stop_tensorboard.click(
stop_tensorboard,
)
@ -710,8 +719,11 @@ def dreambooth_tab(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
]
button_open_config.click(
@ -773,16 +785,20 @@ def UI(**kwargs):
)
# Show the interface
launch_kwargs={}
launch_kwargs = {}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
launch_kwargs['auth'] = (
kwargs.get('username', None),
kwargs.get('password', None),
)
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@ -793,10 +809,20 @@ if __name__ == '__main__':
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
UI(
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
)

View File

@ -91,8 +91,11 @@ def save_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -195,8 +198,11 @@ def open_config_file(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -278,8 +284,11 @@ def train_model(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
):
# create caption json file
if generate_caption_database:
@ -585,7 +594,8 @@ def finetune_tab():
seed,
caption_extension,
cache_latents,
optimizer,optimizer_args,
optimizer,
optimizer_args,
) = gradio_training(learning_rate_value='1e-5')
with gr.Row():
dataset_repeats = gr.Textbox(label='Dataset repeats', value=40)
@ -617,7 +627,9 @@ def finetune_tab():
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
noise_offset,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -631,15 +643,15 @@ def finetune_tab():
)
button_run = gr.Button('Train model', variant='primary')
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
button_start_tensorboard.click(
start_tensorboard,
inputs=logging_dir,
)
button_stop_tensorboard.click(
stop_tensorboard,
)
@ -699,8 +711,11 @@ def finetune_tab():
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
]
button_run.click(train_model, inputs=settings_list)
@ -742,16 +757,19 @@ def UI(**kwargs):
utilities_tab(enable_dreambooth_tab=False)
# Show the interface
launch_kwargs={}
launch_kwargs = {}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
launch_kwargs['auth'] = (
kwargs.get('username', None),
kwargs.get('password', None),
)
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
@ -763,10 +781,20 @@ if __name__ == '__main__':
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
UI(
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
)

View File

@ -53,15 +53,16 @@ def UI(**kwargs):
inbrowser = kwargs.get('inbrowser', False)
share = kwargs.get('share', False)
if username and password:
launch_kwargs["auth"] = (username, password)
launch_kwargs['auth'] = (username, password)
if server_port > 0:
launch_kwargs["server_port"] = server_port
launch_kwargs['server_port'] = server_port
if inbrowser:
launch_kwargs["inbrowser"] = inbrowser
launch_kwargs['inbrowser'] = inbrowser
if share:
launch_kwargs["share"] = share
launch_kwargs['share'] = share
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
parser = argparse.ArgumentParser()
@ -72,11 +73,24 @@ if __name__ == '__main__':
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument(
'--share', action='store_true', help='Share the gradio UI'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
parser.add_argument("--share", action="store_true", help="Share the gradio UI")
args = parser.parse_args()
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port, share=args.share)
UI(
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
share=args.share,
)

View File

@ -9,6 +9,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def update_optimizer(my_data):
if my_data.get('use_8bit_adam', False):
my_data['optimizer'] = 'AdamW8bit'
@ -86,13 +87,18 @@ def remove_doublequote(file_path):
return file_path
def set_legacy_8bitadam(optimizer, use_8bit_adam):
if optimizer == 'AdamW8bit':
# use_8bit_adam = True
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(value=True, interactive=False, visible=True)
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
value=True, interactive=False, visible=True
)
else:
# use_8bit_adam = False
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(value=False, interactive=False, visible=True)
return gr.Dropdown.update(value=optimizer), gr.Checkbox.update(
value=False, interactive=False, visible=True
)
def get_folder_path(folder_path=''):
@ -489,14 +495,15 @@ def gradio_training(
'DAdaptation',
'Lion',
'SGDNesterov',
'SGDNesterov8bit'
'SGDNesterov8bit',
],
value="AdamW8bit",
value='AdamW8bit',
interactive=True,
)
with gr.Row():
optimizer_args = gr.Textbox(
label='Optimizer extra arguments', placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True'
label='Optimizer extra arguments',
placeholder='(Optional) eg: relative_step=True scale_parameter=True warmup_init=True',
)
return (
learning_rate,
@ -549,11 +556,14 @@ def run_cmd_training(**kwargs):
' --cache_latents' if kwargs.get('cache_latents') else '',
# ' --use_lion_optimizer' if kwargs.get('optimizer') == 'Lion' else '',
f' --optimizer_type="{kwargs.get("optimizer", "AdamW")}"',
f' --optimizer_args {kwargs.get("optimizer_args", "")}' if not kwargs.get('optimizer_args') == '' else '',
f' --optimizer_args {kwargs.get("optimizer_args", "")}'
if not kwargs.get('optimizer_args') == ''
else '',
]
run_cmd = ''.join(options)
return run_cmd
# # This function takes a dictionary of keyword arguments and returns a string that can be used to run a command-line training script
# def run_cmd_training(**kwargs):
# arg_map = {
@ -611,7 +621,9 @@ def gradio_advanced_training():
)
with gr.Row():
# This use_8bit_adam element should be removed in a future release as it is no longer used
use_8bit_adam = gr.Checkbox(label='Use 8bit adam', value=False, visible=False)
use_8bit_adam = gr.Checkbox(
label='Use 8bit adam', value=False, visible=False
)
xformers = gr.Checkbox(label='Use xformers', value=True)
color_aug = gr.Checkbox(label='Color augmentation', value=False)
flip_aug = gr.Checkbox(label='Flip augmentation', value=False)
@ -628,17 +640,13 @@ def gradio_advanced_training():
noise_offset = gr.Textbox(
label='Noise offset (0 - 1)', placeholder='(Oprional) eg: 0.1'
)
with gr.Row():
caption_dropout_every_n_epochs = gr.Number(
label="Dropout caption every n epochs",
value=0
label='Dropout caption every n epochs', value=0
)
caption_dropout_rate = gr.Slider(
label="Rate of caption dropout",
value=0,
minimum=0,
maximum=1
label='Rate of caption dropout', value=0, minimum=0, maximum=1
)
with gr.Row():
save_state = gr.Checkbox(label='Save training state', value=False)
@ -676,7 +684,9 @@ def gradio_advanced_training():
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
noise_offset,
)
@ -706,11 +716,9 @@ def run_cmd_advanced_training(**kwargs):
f' --caption_dropout_rate="{kwargs.get("caption_dropout_rate", "")}"'
if float(kwargs.get('caption_dropout_rate', 0)) > 0
else '',
f' --bucket_reso_steps={int(kwargs.get("bucket_reso_steps", 1))}'
if int(kwargs.get('bucket_reso_steps', 64)) >= 1
else '',
' --save_state' if kwargs.get('save_state') else '',
' --mem_eff_attn' if kwargs.get('mem_eff_attn') else '',
' --color_aug' if kwargs.get('color_aug') else '',
@ -734,6 +742,7 @@ def run_cmd_advanced_training(**kwargs):
run_cmd = ''.join(options)
return run_cmd
# def run_cmd_advanced_training(**kwargs):
# arg_map = {
# 'max_train_epochs': ' --max_train_epochs="{}"',
@ -763,4 +772,4 @@ def run_cmd_advanced_training(**kwargs):
# cmd = ''.join(options)
# return cmd
# return cmd

File diff suppressed because it is too large Load Diff

View File

@ -4,43 +4,49 @@ from easygui import msgbox
import subprocess
import time
tensorboard_proc = None # I know... bad but heh
tensorboard_proc = None # I know... bad but heh
def start_tensorboard(logging_dir):
global tensorboard_proc
if not os.listdir(logging_dir):
print("Error: log folder is empty")
msgbox(msg="Error: log folder is empty")
print('Error: log folder is empty')
msgbox(msg='Error: log folder is empty')
return
run_cmd = f'tensorboard.exe --logdir "{logging_dir}"'
print(run_cmd)
if tensorboard_proc is not None:
print("Tensorboard is already running. Terminating existing process before starting new one...")
print(
'Tensorboard is already running. Terminating existing process before starting new one...'
)
stop_tensorboard()
# Start background process
print('Starting tensorboard...')
print('Starting tensorboard...')
tensorboard_proc = subprocess.Popen(run_cmd)
# Wait for some time to allow TensorBoard to start up
time.sleep(5)
# Open the TensorBoard URL in the default browser
print('Opening tensorboard url in browser...')
import webbrowser
webbrowser.open('http://localhost:6006')
def stop_tensorboard():
print('Stopping tensorboard process...')
tensorboard_proc.kill()
print('...process stopped')
def gradio_tensorboard():
with gr.Row():
button_start_tensorboard = gr.Button('Start tensorboard')
button_stop_tensorboard = gr.Button('Stop tensorboard')
return(button_start_tensorboard, button_stop_tensorboard)
return (button_start_tensorboard, button_stop_tensorboard)

File diff suppressed because it is too large Load Diff

View File

@ -50,16 +50,19 @@ def UI(**kwargs):
utilities_tab()
# Show the interface
launch_kwargs={}
launch_kwargs = {}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
launch_kwargs['auth'] = (
kwargs.get('username', None),
kwargs.get('password', None),
)
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
@ -71,10 +74,20 @@ if __name__ == '__main__':
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
UI(
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
)

View File

@ -47,6 +47,7 @@ refresh_symbol = '\U0001f504' # 🔄
save_style_symbol = '\U0001f4be' # 💾
document_symbol = '\U0001F4C4' # 📄
def save_configuration(
save_as,
file_path,
@ -105,9 +106,11 @@ def save_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -211,9 +214,11 @@ def open_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -237,7 +242,7 @@ def open_configuration(
# Set the value in the dictionary to the corresponding value in `my_data`, or the default value if not found
if not key in ['file_path']:
values.append(my_data.get(key, value))
return tuple(values)
@ -297,10 +302,12 @@ def train_model(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
):
optimizer_args,
noise_offset,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
return
@ -723,14 +730,16 @@ def lora_tab(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
noise_offset,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
inputs=[color_aug],
outputs=[cache_latents],
)
optimizer.change(
set_legacy_8bitadam,
inputs=[optimizer, use_8bit_adam],
@ -753,15 +762,15 @@ def lora_tab(
gradio_verify_lora_tab()
button_run = gr.Button('Train model', variant='primary')
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
button_start_tensorboard.click(
start_tensorboard,
inputs=logging_dir,
)
button_stop_tensorboard.click(
stop_tensorboard,
)
@ -822,9 +831,11 @@ def lora_tab(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,noise_offset,
optimizer_args,
noise_offset,
]
button_open_config.click(
@ -886,16 +897,19 @@ def UI(**kwargs):
)
# Show the interface
launch_kwargs={}
launch_kwargs = {}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
launch_kwargs['auth'] = (
kwargs.get('username', None),
kwargs.get('password', None),
)
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
@ -907,10 +921,20 @@ if __name__ == '__main__':
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
UI(
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
)

View File

@ -101,8 +101,11 @@ def save_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -205,8 +208,11 @@ def open_configuration(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
):
# Get list of function parameters and values
parameters = list(locals().items())
@ -288,8 +294,11 @@ def train_model(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
):
if pretrained_model_name_or_path == '':
msgbox('Source model information is missing')
@ -641,7 +650,8 @@ def ti_tab(
seed,
caption_extension,
cache_latents,
optimizer,optimizer_args,
optimizer,
optimizer_args,
) = gradio_training(
learning_rate_value='1e-5',
lr_scheduler_value='cosine',
@ -699,7 +709,9 @@ def ti_tab(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
noise_offset,
) = gradio_advanced_training()
color_aug.change(
color_aug_changed,
@ -723,15 +735,15 @@ def ti_tab(
)
button_run = gr.Button('Train model', variant='primary')
# Setup gradio tensorboard buttons
button_start_tensorboard, button_stop_tensorboard = gradio_tensorboard()
button_start_tensorboard.click(
start_tensorboard,
inputs=logging_dir,
)
button_stop_tensorboard.click(
stop_tensorboard,
)
@ -791,8 +803,11 @@ def ti_tab(
bucket_no_upscale,
random_crop,
bucket_reso_steps,
caption_dropout_every_n_epochs, caption_dropout_rate,
optimizer,optimizer_args,noise_offset,
caption_dropout_every_n_epochs,
caption_dropout_rate,
optimizer,
optimizer_args,
noise_offset,
]
button_open_config.click(
@ -854,16 +869,19 @@ def UI(**kwargs):
)
# Show the interface
launch_kwargs={}
launch_kwargs = {}
if not kwargs.get('username', None) == '':
launch_kwargs["auth"] = (kwargs.get('username', None), kwargs.get('password', None))
launch_kwargs['auth'] = (
kwargs.get('username', None),
kwargs.get('password', None),
)
if kwargs.get('server_port', 0) > 0:
launch_kwargs["server_port"] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs["inbrowser"] = kwargs.get('inbrowser', False)
launch_kwargs['server_port'] = kwargs.get('server_port', 0)
if kwargs.get('inbrowser', False):
launch_kwargs['inbrowser'] = kwargs.get('inbrowser', False)
print(launch_kwargs)
interface.launch(**launch_kwargs)
if __name__ == '__main__':
# torch.cuda.set_per_process_memory_fraction(0.48)
@ -875,10 +893,20 @@ if __name__ == '__main__':
'--password', type=str, default='', help='Password for authentication'
)
parser.add_argument(
'--server_port', type=int, default=0, help='Port to run the server listener on'
'--server_port',
type=int,
default=0,
help='Port to run the server listener on',
)
parser.add_argument(
'--inbrowser', action='store_true', help='Open in browser'
)
parser.add_argument("--inbrowser", action="store_true", help="Open in browser")
args = parser.parse_args()
UI(username=args.username, password=args.password, inbrowser=args.inbrowser, server_port=args.server_port)
UI(
username=args.username,
password=args.password,
inbrowser=args.inbrowser,
server_port=args.server_port,
)