Added Refresh Button to embedding and hypernetwork names in Train Tab

Problem
everytime I modified pt files in embedding_dir or hypernetwork_dir, I
need to restart webui to have the new files shown in the dropdown of
Train Tab

Solution
refactored create_refresh_button out of create_setting_component so we
can use this method to create button next to gr.Dropdowns of embedding
name and hypernetworks

Extra Modification
hypernetwork pt are now sorted in alphabetic order
This commit is contained in:
Junpeng Qiu 2022-10-15 21:42:52 -07:00 committed by AUTOMATIC1111
parent bd4f0fb9d9
commit 36a0ba357a
2 changed files with 27 additions and 20 deletions

View File

@ -568,6 +568,24 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn = refresh,
inputs = [],
outputs = [refresh_component]
)
return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
@ -1205,8 +1223,12 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Tab(label="Train"): with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
@ -1357,26 +1379,11 @@ def create_ui(wrap_gradio_gpu_call):
if info.refresh is not None: if info.refresh is not None:
if is_quicksettings: if is_quicksettings:
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key) refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
else: else:
with gr.Row(variant="compact"): with gr.Row(variant="compact"):
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key) refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key)
def refresh():
info.refresh()
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
for k, v in refreshed_args.items():
setattr(res, k, v)
return gr.update(**(refreshed_args or {}))
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[res],
)
else: else:
res = comp(label=info.label, value=fun, **(args or {})) res = comp(label=info.label, value=fun, **(args or {}))

View File

@ -478,7 +478,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{ #refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;