Merge pull request #10557 from akx/dedupe-webui-boot

Refactor & deduplicate web UI boot code
This commit is contained in:
AUTOMATIC1111 2023-05-20 22:24:15 +03:00 committed by GitHub
commit cc6c0fc70a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 167 additions and 136 deletions

View File

@ -14,6 +14,11 @@ def register_extra_network(extra_network):
extra_network_registry[extra_network.name] = extra_network extra_network_registry[extra_network.name] = extra_network
def register_default_extra_networks():
from modules.extra_networks_hypernet import ExtraNetworkHypernet
register_extra_network(ExtraNetworkHypernet())
class ExtraNetworkParams: class ExtraNetworkParams:
def __init__(self, items=None): def __init__(self, items=None):
self.items = items or [] self.items = items or []

View File

@ -271,6 +271,12 @@ def load_scripts():
sys.path = syspath sys.path = syspath
current_basedir = paths.script_path current_basedir = paths.script_path
global scripts_txt2img, scripts_img2img, scripts_postproc
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def wrap_call(func, filename, funcname, *args, default=None, **kwargs): def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try: try:
@ -527,9 +533,9 @@ class ScriptRunner:
self.scripts[si].args_to = args_to self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner() scripts_txt2img: ScriptRunner = None
scripts_img2img = ScriptRunner() scripts_img2img: ScriptRunner = None
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner() scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None
scripts_current: ScriptRunner = None scripts_current: ScriptRunner = None
@ -539,14 +545,7 @@ def reload_script_body_only():
scripts_img2img.reload_sources(cache) scripts_img2img.reload_sources(cache)
def reload_scripts(): reload_scripts = load_scripts # compatibility alias
global scripts_txt2img, scripts_img2img, scripts_postproc
load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def add_classes_to_gradio_component(comp): def add_classes_to_gradio_component(comp):

View File

@ -98,7 +98,6 @@ def setup_model():
if not os.path.exists(model_path): if not os.path.exists(model_path):
os.makedirs(model_path) os.makedirs(model_path)
list_models()
enable_midas_autodownload() enable_midas_autodownload()

View File

@ -232,10 +232,19 @@ class ExtraNetworksPage:
return None return None
def intialize(): def initialize():
extra_pages.clear() extra_pages.clear()
def register_default_pages():
from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion
from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks
from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints
register_page(ExtraNetworksPageTextualInversion())
register_page(ExtraNetworksPageHypernetworks())
register_page(ExtraNetworksPageCheckpoints())
class ExtraNetworksUi: class ExtraNetworksUi:
def __init__(self): def __init__(self):
self.pages = None self.pages = None

265
webui.py
View File

@ -7,6 +7,7 @@ import re
import warnings import warnings
import json import json
from threading import Thread from threading import Thread
from typing import Iterable
from fastapi import FastAPI, Response from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -14,6 +15,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from packaging import version from packaging import version
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import paths, timer, import_hook, errors # noqa: F401 from modules import paths, timer, import_hook, errors # noqa: F401
@ -34,8 +36,7 @@ startup_timer.record("import gradio")
import ldm.modules.encoders.modules # noqa: F401 import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm") startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints from modules import extra_networks
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
@ -162,80 +163,50 @@ def restore_config_state_file():
print(f"!!! Config state backup not found: {config_state_file}") print(f"!!! Config state backup not found: {config_state_file}")
def initialize(): def validate_tls_options():
fix_asyncio_event_loop_policy() if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
check_versions()
extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
restore_config_state_file()
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
return return
modelloader.cleanup_models() try:
modules.sd_models.setup_model() if not os.path.exists(cmd_opts.tls_keyfile):
startup_timer.record("list SD models") print("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
startup_timer.record("TLS")
codeformer.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
gfpgan.setup_model(cmd_opts.gfpgan_models_path) def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]:
startup_timer.record("setup gfpgan") """
Convert the gradio_auth and gradio_auth_path commandline arguments into
an iterable of (username, password) tuples.
"""
def process_credential_line(s) -> tuple[str, ...] | None:
s = s.strip()
if not s:
return None
return tuple(s.split(':', 1))
modules.scripts.load_scripts() if cmd_opts.gradio_auth:
startup_timer.record("load scripts") for cred in cmd_opts.gradio_auth.split(','):
cred = process_credential_line(cred)
if cred:
yield cred
modelloader.load_upscalers() if cmd_opts.gradio_auth_path:
startup_timer.record("load upscalers") with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
for cred in line.strip().split(','):
cred = process_credential_line(cred)
if cred:
yield cred
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
# load model in parallel to other startup stuff
Thread(target=lambda: shared.sd_model).start()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
startup_timer.record("opts onchange")
shared.reload_hypernetworks()
startup_timer.record("reload hypernets")
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
startup_timer.record("extra networks")
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
try:
if not os.path.exists(cmd_opts.tls_keyfile):
print("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
startup_timer.record("TLS")
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
@ -247,16 +218,104 @@ def initialize():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def configure_opts_onchange():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
startup_timer.record("opts onchange")
def initialize():
fix_asyncio_event_loop_policy()
validate_tls_options()
configure_sigint_handler()
check_versions()
modelloader.cleanup_models()
configure_opts_onchange()
modules.sd_models.setup_model()
startup_timer.record("setup SD model")
codeformer.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
sd_samplers.set_samplers()
extensions.list_extensions()
startup_timer.record("list extensions")
restore_config_state_file()
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts()
return
modules.sd_models.list_models()
startup_timer.record("list SD models")
localization.list_localizations(cmd_opts.localizations_dir)
modules.scripts.load_scripts()
startup_timer.record("load scripts")
if reload_script_modules:
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
startup_timer.record("reload script modules")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
# load model in parallel to other startup stuff
# (when reloading, this does nothing)
Thread(target=lambda: shared.sd_model).start()
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
ui_extra_networks.initialize()
ui_extra_networks.register_default_pages()
extra_networks.initialize()
extra_networks.register_default_extra_networks()
startup_timer.record("initialize extra networks")
def setup_middleware(app): def setup_middleware(app):
app.middleware_stack = None # reset current middleware to allow modifying user provided list app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: configure_cors_middleware(app)
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) app.build_middleware_stack() # rebuild middleware stack on-the-fly
elif cmd_opts.cors_allow_origins:
app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*'])
elif cmd_opts.cors_allow_origins_regex: def configure_cors_middleware(app):
app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) cors_options = {
app.build_middleware_stack() # rebuild middleware stack on-the-fly "allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.cors_allow_origins:
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
if cmd_opts.cors_allow_origins_regex:
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)
def create_api(app): def create_api(app):
@ -301,16 +360,11 @@ def webui():
if not cmd_opts.no_gradio_queue: if not cmd_opts.no_gradio_queue:
shared.demo.queue(64) shared.demo.queue(64)
gradio_auth_creds = [] gradio_auth_creds = list(get_gradio_auth_creds()) or None
if cmd_opts.gradio_auth:
gradio_auth_creds += [x.strip() for x in cmd_opts.gradio_auth.strip('"').replace('\n', '').split(',') if x.strip()]
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
# this restores the missing /docs endpoint # this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'): if launch_api and not hasattr(FastAPI, 'original_setup'):
# TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged
def fastapi_setup(self): def fastapi_setup(self):
self.docs_url = "/docs" self.docs_url = "/docs"
self.redoc_url = "/redoc" self.redoc_url = "/redoc"
@ -327,7 +381,7 @@ def webui():
ssl_certfile=cmd_opts.tls_certfile, ssl_certfile=cmd_opts.tls_certfile,
ssl_verify=cmd_opts.disable_tls_verify, ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug, debug=cmd_opts.gradio_debug,
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, auth=gradio_auth_creds,
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True, prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path, allowed_paths=cmd_opts.gradio_allowed_path,
@ -386,47 +440,12 @@ def webui():
print('Restarting UI...') print('Restarting UI...')
shared.demo.close() shared.demo.close()
time.sleep(0.5) time.sleep(0.5)
modules.script_callbacks.app_reload_callback()
startup_timer.reset() startup_timer.reset()
modules.script_callbacks.app_reload_callback()
sd_samplers.set_samplers() startup_timer.record("app reload callback")
modules.script_callbacks.script_unloaded_callback() modules.script_callbacks.script_unloaded_callback()
extensions.list_extensions() startup_timer.record("scripts unloaded callback")
startup_timer.record("list extensions") initialize_rest(reload_script_modules=True)
restore_config_state_file()
localization.list_localizations(cmd_opts.localizations_dir)
modules.scripts.reload_scripts()
startup_timer.record("load scripts")
modules.script_callbacks.model_loaded_callback(shared.sd_model)
startup_timer.record("model loaded callback")
modelloader.load_upscalers()
startup_timer.record("load upscalers")
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
startup_timer.record("reload script modules")
modules.sd_models.list_models()
startup_timer.record("list SD models")
shared.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks())
ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints())
extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
startup_timer.record("initialize extra networks")
if __name__ == "__main__": if __name__ == "__main__":