make existing script loading and new preload code use same code for loading modules

limit extension preload scripts to just one file named preload.py
This commit is contained in:
AUTOMATIC 2022-11-12 10:56:06 +03:00
parent e5690d0bf2
commit a1a376331c
4 changed files with 51 additions and 51 deletions

View File

@ -1,7 +1,6 @@
import os
import sys
import traceback
from importlib.machinery import SourceFileLoader
import git
@ -85,23 +84,3 @@ def list_extensions():
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)
def preload_extensions(parser):
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
path = os.path.join(extensions_dir, dirname)
if not os.path.isdir(path):
continue
for file in os.listdir(path):
if "preload.py" in file:
full_file = os.path.join(path, file)
print(f"Got preload file: {full_file}")
try:
ext = SourceFileLoader("preload", full_file).load_module()
parser = ext.preload(parser)
except Exception as e:
print(f"Exception preloading script: {e}")
return parser

34
modules/script_loading.py Normal file
View File

@ -0,0 +1,34 @@
import os
import sys
import traceback
from types import ModuleType
def load_module(path):
with open(path, "r", encoding="utf8") as file:
text = file.read()
compiled = compile(text, path, 'exec')
module = ModuleType(os.path.basename(path))
exec(compiled, module.__dict__)
return module
def preload_extensions(extensions_dir, parser):
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
preload_script = os.path.join(extensions_dir, dirname, "preload.py")
if not os.path.isfile(preload_script):
continue
try:
module = load_module(preload_script)
if hasattr(module, 'preload'):
module.preload(parser)
except Exception:
print(f"Error running preload() for {preload_script}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)

View File

@ -6,7 +6,7 @@ from collections import namedtuple
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions
from modules import shared, paths, script_callbacks, extensions, script_loading
AlwaysVisible = object()
@ -161,13 +161,7 @@ def load_scripts():
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
compiled = compile(text, scriptfile.path, 'exec')
module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
module = script_loading.load_module(scriptfile.path)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
@ -328,27 +322,21 @@ class ScriptRunner:
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
args_to = script.args_to
filename = script.filename
text = file.read()
args_from = script.args_from
args_to = script.args_to
filename = script.filename
from types import ModuleType
module = cache.get(filename, None)
if module is None:
module = script_loading.load_module(script.filename)
cache[filename] = module
module = cache.get(filename, None)
if module is None:
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
self.scripts[si] = script_class()
self.scripts[si].filename = filename
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()

View File

@ -3,7 +3,6 @@ import datetime
import json
import os
import sys
from collections import OrderedDict
import time
import gradio as gr
@ -15,7 +14,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization, sd_vae, extensions
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path
@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
extensions.preload_extensions(parser)
script_loading.preload_extensions(extensions.extensions_dir, parser)
cmd_opts = parser.parse_args()