diff --git a/modules/call_queue.py b/modules/call_queue.py new file mode 100644 index 00000000..4cd49533 --- /dev/null +++ b/modules/call_queue.py @@ -0,0 +1,98 @@ +import html +import sys +import threading +import traceback +import time + +from modules import shared + +queue_lock = threading.Lock() + + +def wrap_queued_call(func): + def f(*args, **kwargs): + with queue_lock: + res = func(*args, **kwargs) + + return res + + return f + + +def wrap_gradio_gpu_call(func, extra_outputs=None): + def f(*args, **kwargs): + + shared.state.begin() + + with queue_lock: + res = func(*args, **kwargs) + + shared.state.end() + + return res + + return wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) + + +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): + run_memmon = shared.opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats + if run_memmon: + shared.mem_mon.monitor() + t = time.perf_counter() + + try: + res = list(func(*args, **kwargs)) + except Exception as e: + # When printing out our debug argument list, do not print out more than a MB of text + max_debug_str_len = 131072 # (1024*1024)/8 + + print("Error completing request", file=sys.stderr) + argStr = f"Arguments: {str(args)} {str(kwargs)}" + print(argStr[:max_debug_str_len], file=sys.stderr) + if len(argStr) > max_debug_str_len: + print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) + + print(traceback.format_exc(), file=sys.stderr) + + shared.state.job = "" + shared.state.job_count = 0 + + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{html.escape(type(e).__name__+': '+str(e))}
"] + + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + + elapsed = time.perf_counter() - t + elapsed_m = int(elapsed // 60) + elapsed_s = elapsed % 60 + elapsed_text = f"{elapsed_s:.2f}s" + if elapsed_m > 0: + elapsed_text = f"{elapsed_m}m "+elapsed_text + + if run_memmon: + mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} + active_peak = mem_stats['active_peak'] + reserved_peak = mem_stats['reserved_peak'] + sys_peak = mem_stats['system_peak'] + sys_total = mem_stats['total'] + sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) + + vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" + else: + vram_html = '' + + # last item is always HTML + res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + + return tuple(res) + + return f + diff --git a/modules/ui.py b/modules/ui.py index 446bee40..00809361 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -17,7 +17,7 @@ import gradio.routes import gradio.utils import numpy as np from PIL import Image, PngImagePlugin - +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru from modules.paths import script_path @@ -158,67 +158,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func, extra_outputs=None, add_stats=False): - def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats - if run_memmon: - shared.mem_mon.monitor() - t = time.perf_counter() - - try: - res = list(func(*args, **kwargs)) - except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {str(args)} {str(kwargs)}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) - - shared.state.job = "" - shared.state.job_count = 0 - - if extra_outputs_array is None: - extra_outputs_array = [None, ''] - - res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] - - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - - if not add_stats: - return tuple(res) - - elapsed = time.perf_counter() - t - elapsed_m = int(elapsed // 60) - elapsed_s = elapsed % 60 - elapsed_text = f"{elapsed_s:.2f}s" - if elapsed_m > 0: - elapsed_text = f"{elapsed_m}m "+elapsed_text - - if run_memmon: - mem_stats = {k: -(v//-(1024*1024)) for k, v in shared.mem_mon.stop().items()} - active_peak = mem_stats['active_peak'] - reserved_peak = mem_stats['reserved_peak'] - sys_peak = mem_stats['system_peak'] - sys_total = mem_stats['total'] - sys_pct = round(sys_peak/max(sys_total, 1) * 100, 2) - - vram_html = f"

Torch active/reserved: {active_peak}/{reserved_peak} MiB, Sys VRAM: {sys_peak}/{sys_total} MiB ({sys_pct}%)

" - else: - vram_html = '' - - # last item is always HTML - res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - - return tuple(res) - - return f def calc_time_left(progress, threshold, label, force_display): @@ -666,7 +605,7 @@ Requested path was: {f} return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info -def create_ui(wrap_gradio_gpu_call): +def create_ui(): import modules.img2img import modules.txt2img @@ -826,7 +765,7 @@ def create_ui(wrap_gradio_gpu_call): height, ] - token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) diff --git a/webui.py b/webui.py index 7a56bde8..16e7ec1a 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir @@ -32,38 +33,12 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork -queue_lock = threading.Lock() if cmd_opts.server_name: server_name = cmd_opts.server_name else: server_name = "0.0.0.0" if cmd_opts.listen else None -def wrap_queued_call(func): - def f(*args, **kwargs): - with queue_lock: - res = func(*args, **kwargs) - - return res - - return f - - -def wrap_gradio_gpu_call(func, extra_outputs=None): - def f(*args, **kwargs): - - shared.state.begin() - - with queue_lock: - res = func(*args, **kwargs) - - shared.state.end() - - return res - - return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) - - def initialize(): extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) @@ -159,7 +134,7 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() - shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, @@ -189,6 +164,7 @@ def webui(): create_api(app) modules.script_callbacks.app_started_callback(shared.demo, app) + modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo)