From 9a33292ce41b01252cdb8ab6214a11d274e32fa0 Mon Sep 17 00:00:00 2001 From: zhengxiaoyao0716 <1499383852@qq.com> Date: Sat, 15 Oct 2022 01:04:47 +0800 Subject: [PATCH 01/23] reload javascript files when custom script bodies --- modules/ui.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index b867d40f..90b8646b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,7 +12,7 @@ import time import traceback import platform import subprocess as sp -from functools import reduce +from functools import partial, reduce import numpy as np import torch @@ -1491,6 +1491,7 @@ Requested path was: {f} def reload_scripts(): modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page reload_script_bodies.click( fn=reload_scripts, @@ -1738,22 +1739,25 @@ Requested path was: {f} return demo -with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' +def load_javascript(raw_response): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' -jsdir = os.path.join(script_path, "javascript") -for filename in sorted(os.listdir(jsdir)): - with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: - javascript += f"\n" + jsdir = os.path.join(script_path, "javascript") + for filename in sorted(os.listdir(jsdir)): + with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: + javascript += f"\n" - -if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): - res = gradio_routes_templates_response(*args, **kwargs) - res.body = res.body.replace(b'', f'{javascript}'.encode("utf8")) + res = raw_response(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) res.init_headers() return res - gradio_routes_templates_response = gradio.routes.templates.TemplateResponse gradio.routes.templates.TemplateResponse = template_response + +reload_javascript = partial(load_javascript, + gradio.routes.templates.TemplateResponse) +reload_javascript() From 60251c9456f5472784862896c2f97e38feb42482 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 06:58:42 +0000 Subject: [PATCH 02/23] initial prototype by borrowing contracts --- modules/api/api.py | 60 +++++++++++++++++++++++++++++++++++++ modules/processing.py | 2 +- modules/shared.py | 2 +- webui.py | 69 +++++++++++++++++++++++++------------------ 4 files changed, 102 insertions(+), 31 deletions(-) create mode 100644 modules/api/api.py diff --git a/modules/api/api.py b/modules/api/api.py new file mode 100644 index 00000000..9d7c699d --- /dev/null +++ b/modules/api/api.py @@ -0,0 +1,60 @@ +from modules.api.processing import StableDiffusionProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, process_images +import modules.shared as shared +import uvicorn +from fastapi import FastAPI, Body, APIRouter +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, Json +import json +import io +import base64 + +app = FastAPI() + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + +class Api: + def __init__(self, txt2img, img2img, run_extras, run_pnginfo): + self.router = APIRouter() + app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) + + def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + print(txt2imgreq) + p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) + p.sd_model = shared.sd_model + print(p) + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + response = { + "images": b64images, + "info": processed.js(), + "parameters": json.dumps(vars(txt2imgreq)) + } + + + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) + + + + def img2imgendoint(self): + raise NotImplementedError + + def extrasendoint(self): + raise NotImplementedError + + def pnginfoendoint(self): + raise NotImplementedError + + def launch(self, server_name, port): + app.include_router(self.router) + uvicorn.run(app, host=server_name, port=port) \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index deb6125e..4a7c6ccc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -723,4 +723,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x devices.torch_gc() - return samples + return samples \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index c2775603..6c6405fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) - +parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/webui.py b/webui.py index fe0ce321..cd8a99ea 100644 --- a/webui.py +++ b/webui.py @@ -97,40 +97,51 @@ def webui(): os._exit(0) signal.signal(signal.SIGINT, sigint_handler) + + if cmd_opts.api: + from modules.api.api import Api + api = Api(txt2img=modules.txt2img.txt2img, + img2img=modules.img2img.img2img, + run_extras=modules.extras.run_extras, + run_pnginfo=modules.extras.run_pnginfo) - while 1: - - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - - app, local_url, share_url = demo.launch( - share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, - server_port=cmd_opts.port, - debug=cmd_opts.gradio_debug, - auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, - inbrowser=cmd_opts.autolaunch, - prevent_thread_lock=True - ) - - app.add_middleware(GZipMiddleware, minimum_size=1000) + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + port=cmd_opts.port if cmd_opts.port else 7861) + else: while 1: - time.sleep(0.5) - if getattr(demo, 'do_restart', False): - time.sleep(0.5) - demo.close() - time.sleep(0.5) - break - sd_samplers.set_samplers() + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) - print('Reloading modules: modules.ui') - importlib.reload(modules.ui) - print('Refreshing Model List') - modules.sd_models.list_models() - print('Restarting Gradio') + app, local_url, share_url = demo.launch( + share=cmd_opts.share, + server_name="0.0.0.0" if cmd_opts.listen else None, + server_port=cmd_opts.port, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + while 1: + time.sleep(0.5) + if getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + sd_samplers.set_samplers() + + print('Reloading Custom Scripts') + modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + print('Reloading modules: modules.ui') + importlib.reload(modules.ui) + print('Refreshing Model List') + modules.sd_models.list_models() + print('Restarting Gradio') if __name__ == "__main__": From 9e02812afd10582f00a7fbbfa63c8f9188678e26 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 07:02:08 +0000 Subject: [PATCH 03/23] pydantic instrumentation --- modules/api/processing.py | 99 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 modules/api/processing.py diff --git a/modules/api/processing.py b/modules/api/processing.py new file mode 100644 index 00000000..459a8f49 --- /dev/null +++ b/modules/api/processing.py @@ -0,0 +1,99 @@ +from inflection import underscore +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field, create_model +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +import inspect + + +class ModelDef(BaseModel): + """Assistance Class for Pydantic Dynamic Model Generation""" + + field: str + field_alias: str + field_type: Any + field_value: Any + + +class pydanticModelGenerator: + """ + Takes source_data:Dict ( a single instance example of something like a JSON node) and self generates a pythonic data model with Alias to original source field names. This makes it easy to popuate or export to other systems yet handle the data in a pythonic way. + Being a pydantic datamodel all the richness of pydantic data validation is available and these models can easily be used in FastAPI and or a ORM + + It does not process full JSON data structures but takes simple JSON document with basic elements + + Provide a model_name, an example of JSON data and a dict of type overrides + + Example: + + source_data = {'Name': '48 Rainbow Rd', + 'GroupAddressStyle': 'ThreeLevel', + 'LastModified': '2020-12-21T07:02:51.2400232Z', + 'ProjectStart': '2020-12-03T07:36:03.324856Z', + 'Comment': '', + 'CompletionStatus': 'Editing', + 'LastUsedPuid': '955', + 'Guid': '0c85957b-c2ae-4985-9752-b300ab385b36'} + + source_overrides = {'Guid':{'type':uuid.UUID}, + 'LastModified':{'type':datetime }, + 'ProjectStart':{'type':datetime }, + } + source_optionals = {"Comment":True} + + #create Model + model_Project=pydanticModelGenerator( + model_name="Project", + source_data=source_data, + overrides=source_overrides, + optionals=source_optionals).generate_model() + + #create instance using DynamicModel + project_instance=model_Project(**project_info) + + """ + + def __init__( + self, + model_name: str = None, + source_data: str = None, + params: Dict = {}, + overrides: Dict = {}, + optionals: Dict = {}, + ): + def field_type_generator(k, v, overrides, optionals): + print(k, v) + field_type = str if not overrides.get(k) else overrides[k]["type"] + if v is None: + field_type = Any + else: + field_type = type(v) + + return Optional[field_type] + + self._model_name = model_name + self._json_data = source_data + self._model_def = [ + ModelDef( + field=underscore(k), + field_alias=k, + field_type=field_type_generator(k, v, overrides, optionals), + field_value=v + ) + for (k,v) in source_data.items() if k in params + ] + + def generate_model(self): + """ + Creates a pydantic BaseModel + from the json and overrides provided at initialization + """ + fields = { + d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def + } + DynamicModel = create_model(self._model_name, **fields) + DynamicModel.__config__.allow_population_by_field_name = True + return DynamicModel + +StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", + StableDiffusionProcessing().__dict__, + inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() \ No newline at end of file From f3fe487e6340b1a2db5d2e2ddf5ae885b4eef54c Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:14:53 -0400 Subject: [PATCH 04/23] Update webui.py --- webui.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/webui.py b/webui.py index cd8a99ea..603a4ccd 100644 --- a/webui.py +++ b/webui.py @@ -100,10 +100,7 @@ def webui(): if cmd_opts.api: from modules.api.api import Api - api = Api(txt2img=modules.txt2img.txt2img, - img2img=modules.img2img.img2img, - run_extras=modules.extras.run_extras, - run_pnginfo=modules.extras.run_pnginfo) + api = Api() api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) From 832b490e5173f78c4d3aa7ca9ca9ac794d140664 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:18:41 -0400 Subject: [PATCH 05/23] Update processing.py --- modules/api/processing.py | 41 +++++---------------------------------- 1 file changed, 5 insertions(+), 36 deletions(-) diff --git a/modules/api/processing.py b/modules/api/processing.py index 459a8f49..4c3d0bd0 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -16,46 +16,15 @@ class ModelDef(BaseModel): class pydanticModelGenerator: """ - Takes source_data:Dict ( a single instance example of something like a JSON node) and self generates a pythonic data model with Alias to original source field names. This makes it easy to popuate or export to other systems yet handle the data in a pythonic way. - Being a pydantic datamodel all the richness of pydantic data validation is available and these models can easily be used in FastAPI and or a ORM - - It does not process full JSON data structures but takes simple JSON document with basic elements - - Provide a model_name, an example of JSON data and a dict of type overrides - - Example: - - source_data = {'Name': '48 Rainbow Rd', - 'GroupAddressStyle': 'ThreeLevel', - 'LastModified': '2020-12-21T07:02:51.2400232Z', - 'ProjectStart': '2020-12-03T07:36:03.324856Z', - 'Comment': '', - 'CompletionStatus': 'Editing', - 'LastUsedPuid': '955', - 'Guid': '0c85957b-c2ae-4985-9752-b300ab385b36'} - - source_overrides = {'Guid':{'type':uuid.UUID}, - 'LastModified':{'type':datetime }, - 'ProjectStart':{'type':datetime }, - } - source_optionals = {"Comment":True} - - #create Model - model_Project=pydanticModelGenerator( - model_name="Project", - source_data=source_data, - overrides=source_overrides, - optionals=source_optionals).generate_model() - - #create instance using DynamicModel - project_instance=model_Project(**project_info) - + Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: + source_data is a snapshot of the default values produced by the class + params are the names of the actual keys required by __init__ """ def __init__( self, model_name: str = None, - source_data: str = None, + source_data: {} = {}, params: Dict = {}, overrides: Dict = {}, optionals: Dict = {}, @@ -96,4 +65,4 @@ class pydanticModelGenerator: StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", StableDiffusionProcessing().__dict__, - inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() \ No newline at end of file + inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() From 99013ba68a5fe1bde3621632e5539c03562a3ae8 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:20:17 -0400 Subject: [PATCH 06/23] Update processing.py --- modules/api/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/api/processing.py b/modules/api/processing.py index 4c3d0bd0..e4df93c5 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -30,7 +30,6 @@ class pydanticModelGenerator: optionals: Dict = {}, ): def field_type_generator(k, v, overrides, optionals): - print(k, v) field_type = str if not overrides.get(k) else overrides[k]["type"] if v is None: field_type = Any From 71d42bb44b257f3fb274c3ad5075a195281ff915 Mon Sep 17 00:00:00 2001 From: Jonathan Date: Mon, 17 Oct 2022 03:22:19 -0400 Subject: [PATCH 07/23] Update api.py --- modules/api/api.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9d7c699d..4d9619a8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -23,10 +23,8 @@ class Api: app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): - print(txt2imgreq) p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) p.sd_model = shared.sd_model - print(p) processed = process_images(p) b64images = [] @@ -34,13 +32,6 @@ class Api: buffer = io.BytesIO() i.save(buffer, format="png") b64images.append(base64.b64encode(buffer.getvalue())) - - response = { - "images": b64images, - "info": processed.js(), - "parameters": json.dumps(vars(txt2imgreq)) - } - return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) @@ -57,4 +48,4 @@ class Api: def launch(self, server_name, port): app.include_router(self.router) - uvicorn.run(app, host=server_name, port=port) \ No newline at end of file + uvicorn.run(app, host=server_name, port=port) From 964b63c0423a861bd67c40b59f767e7037051083 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 11:38:32 +0300 Subject: [PATCH 08/23] add api() function to return webui() to how it was --- webui.py | 85 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/webui.py b/webui.py index 603a4ccd..16c862f0 100644 --- a/webui.py +++ b/webui.py @@ -87,59 +87,62 @@ def initialize(): shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) - -def webui(): - initialize() - # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') os._exit(0) signal.signal(signal.SIGINT, sigint_handler) - - if cmd_opts.api: - from modules.api.api import Api - api = Api() - api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", - port=cmd_opts.port if cmd_opts.port else 7861) - else: +def api() + initialize() + + from modules.api.api import Api + api = Api() + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + + +def webui(): + initialize() + + while 1: + + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + + app, local_url, share_url = demo.launch( + share=cmd_opts.share, + server_name="0.0.0.0" if cmd_opts.listen else None, + server_port=cmd_opts.port, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + while 1: - - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - - app, local_url, share_url = demo.launch( - share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, - server_port=cmd_opts.port, - debug=cmd_opts.gradio_debug, - auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, - inbrowser=cmd_opts.autolaunch, - prevent_thread_lock=True - ) - - app.add_middleware(GZipMiddleware, minimum_size=1000) - - while 1: + time.sleep(0.5) + if getattr(demo, 'do_restart', False): time.sleep(0.5) - if getattr(demo, 'do_restart', False): - time.sleep(0.5) - demo.close() - time.sleep(0.5) - break + demo.close() + time.sleep(0.5) + break - sd_samplers.set_samplers() + sd_samplers.set_samplers() - print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) - print('Reloading modules: modules.ui') - importlib.reload(modules.ui) - print('Refreshing Model List') - modules.sd_models.list_models() - print('Restarting Gradio') + print('Reloading Custom Scripts') + modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + print('Reloading modules: modules.ui') + importlib.reload(modules.ui) + print('Refreshing Model List') + modules.sd_models.list_models() + print('Restarting Gradio') if __name__ == "__main__": - webui() + if cmd_opts.api: + api() + else: + webui() From d42125baf62880854ad06af06c15c23e7e50cca6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 11:50:20 +0300 Subject: [PATCH 09/23] add missing requirement for api and fix some typos --- modules/api/api.py | 2 +- requirements.txt | 1 + requirements_versions.txt | 1 + webui.py | 2 +- 4 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 4d9619a8..fd09d352 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,7 @@ class TextToImageResponse(BaseModel): class Api: - def __init__(self, txt2img, img2img, run_extras, run_pnginfo): + def __init__(self): self.router = APIRouter() app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) diff --git a/requirements.txt b/requirements.txt index cf583de9..da1969cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ resize-right torchdiffeq kornia lark +inflection diff --git a/requirements_versions.txt b/requirements_versions.txt index abadcb58..72ccc5a3 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,3 +22,4 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 +inflection==0.5.1 diff --git a/webui.py b/webui.py index 16c862f0..eeee44c3 100644 --- a/webui.py +++ b/webui.py @@ -95,7 +95,7 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) -def api() +def api(): initialize() from modules.api.api import Api From f80e914ac4aa69a9783b4040813253500b34d925 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 19:10:36 +0000 Subject: [PATCH 10/23] example API working with gradio --- modules/api/api.py | 9 +++++-- modules/api/processing.py | 56 ++++++++++++++++++++++++++------------- modules/processing.py | 22 ++++++++++----- 3 files changed, 60 insertions(+), 27 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index fd09d352..5e86c3bf 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -23,8 +23,13 @@ class Api: app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): - p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) - p.sd_model = shared.sd_model + populate = txt2imgreq.copy(update={ # Override __init__ params + "sd_model": shared.sd_model, + "sampler_index": 0, + } + ) + p = StableDiffusionProcessingTxt2Img(**vars(populate)) + # Override object param processed = process_images(p) b64images = [] diff --git a/modules/api/processing.py b/modules/api/processing.py index e4df93c5..b6798241 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu import inspect +API_NOT_ALLOWED = [ + "self", + "kwargs", + "sd_model", + "outpath_samples", + "outpath_grids", + "sampler_index", + "do_not_save_samples", + "do_not_save_grid", + "extra_generation_params", + "overlay_images", + "do_not_reload_embeddings", + "seed_enable_extras", + "prompt_for_display", + "sampler_noise_scheduler_override", + "ddim_discretize" +] + class ModelDef(BaseModel): """Assistance Class for Pydantic Dynamic Model Generation""" @@ -14,7 +32,7 @@ class ModelDef(BaseModel): field_value: Any -class pydanticModelGenerator: +class PydanticModelGenerator: """ Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: source_data is a snapshot of the default values produced by the class @@ -24,30 +42,33 @@ class pydanticModelGenerator: def __init__( self, model_name: str = None, - source_data: {} = {}, - params: Dict = {}, - overrides: Dict = {}, - optionals: Dict = {}, + class_instance = None ): - def field_type_generator(k, v, overrides, optionals): - field_type = str if not overrides.get(k) else overrides[k]["type"] - if v is None: - field_type = Any - else: - field_type = type(v) + def field_type_generator(k, v): + # field_type = str if not overrides.get(k) else overrides[k]["type"] + # print(k, v.annotation, v.default) + field_type = v.annotation return Optional[field_type] + def merge_class_params(class_): + all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_))) + parameters = {} + for classes in all_classes: + parameters = {**parameters, **inspect.signature(classes.__init__).parameters} + return parameters + + self._model_name = model_name - self._json_data = source_data + self._class_data = merge_class_params(class_instance) self._model_def = [ ModelDef( field=underscore(k), field_alias=k, - field_type=field_type_generator(k, v, overrides, optionals), - field_value=v + field_type=field_type_generator(k, v), + field_value=v.default ) - for (k,v) in source_data.items() if k in params + for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] def generate_model(self): @@ -60,8 +81,7 @@ class pydanticModelGenerator: } DynamicModel = create_model(self._model_name, **fields) DynamicModel.__config__.allow_population_by_field_name = True + DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", - StableDiffusionProcessing().__dict__, - inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model() +StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model() diff --git a/modules/processing.py b/modules/processing.py index 4a7c6ccc..024a4fc3 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps import random import cv2 from skimage import exposure +from typing import Any, Dict, List, Optional import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram @@ -51,9 +52,15 @@ def get_correct_sampler(p): return sd_samplers.samplers elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): return sd_samplers.samplers_for_img2img + elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI): + return sd_samplers.samplers -class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False): +class StableDiffusionProcessing(): + """ + The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing + + """ + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -86,10 +93,10 @@ class StableDiffusionProcessing: self.denoising_strength: float = 0 self.sampler_noise_scheduler_override = None self.ddim_discretize = opts.ddim_discretize - self.s_churn = opts.s_churn - self.s_tmin = opts.s_tmin - self.s_tmax = float('inf') # not representable as a standard ui option - self.s_noise = opts.s_noise + self.s_churn = s_churn or opts.s_churn + self.s_tmin = s_tmin or opts.s_tmin + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option + self.s_noise = s_noise or opts.s_noise if not seed_enable_extras: self.subseed = -1 @@ -97,6 +104,7 @@ class StableDiffusionProcessing: self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): + def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength From f29b16bad19b6332a15b2ef439864d866277fffb Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 20:36:14 +0000 Subject: [PATCH 11/23] prevent API from saving --- modules/api/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 5e86c3bf..ce72c5ee 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -26,6 +26,8 @@ class Api: populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, "sampler_index": 0, + "do_not_save_samples": True, + "do_not_save_grid": True } ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) From c3851a853d99ad35ccedcdd8dbeb6cfbe273439b Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:49:33 -0400 Subject: [PATCH 12/23] Re-use webui fastapi application rather than requiring one or the other, not both. --- modules/api/api.py | 6 ++---- webui.py | 14 +++++++------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index ce72c5ee..8781cd86 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -2,15 +2,13 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images import modules.shared as shared import uvicorn -from fastapi import FastAPI, Body, APIRouter +from fastapi import Body, APIRouter from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, Json import json import io import base64 -app = FastAPI() - class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json @@ -18,7 +16,7 @@ class TextToImageResponse(BaseModel): class Api: - def __init__(self): + def __init__(self, app): self.router = APIRouter() app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) diff --git a/webui.py b/webui.py index eeee44c3..6b55fbed 100644 --- a/webui.py +++ b/webui.py @@ -96,14 +96,11 @@ def initialize(): def api(): - initialize() - from modules.api.api import Api - api = Api() - api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) + api = Api(app) -def webui(): +def webui(launch_api=False): initialize() while 1: @@ -122,6 +119,9 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) + if (launch_api): + api(app) + while 1: time.sleep(0.5) if getattr(demo, 'do_restart', False): @@ -143,6 +143,6 @@ def webui(): if __name__ == "__main__": if cmd_opts.api: - api() + webui(True) else: - webui() + webui(False) From 247aeb3aaaf2925c7d68a9cf47c975f3e6d3dd33 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:50:45 -0400 Subject: [PATCH 13/23] Put API under /sdapi/ so that routing is simpler in the future. This means that one could allow access to /sdapi/ but not the webui. --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 8781cd86..14613d8c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,7 @@ class TextToImageResponse(BaseModel): class Api: def __init__(self, app): self.router = APIRouter() - app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) + app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): populate = txt2imgreq.copy(update={ # Override __init__ params From 1df3ff25e6fe2e3f308e45f7a6dd37fb4f1988e6 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:58:34 -0400 Subject: [PATCH 14/23] Add --nowebui as a means of disabling the webui and run on the other port --- modules/shared.py | 3 ++- webui.py | 35 +++++++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 6c6405fd..8b436970 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/webui.py b/webui.py index 6b55fbed..6212be01 100644 --- a/webui.py +++ b/webui.py @@ -95,16 +95,34 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) -def api(): +def create_api(app): from modules.api.api import Api api = Api(app) + return api + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if demo and getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + +def api_only(): + initialize() + + app = FastAPI() + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) def webui(launch_api=False): initialize() while 1: - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) app, local_url, share_url = demo.launch( @@ -120,15 +138,9 @@ def webui(launch_api=False): app.add_middleware(GZipMiddleware, minimum_size=1000) if (launch_api): - api(app) + create_api(app) - while 1: - time.sleep(0.5) - if getattr(demo, 'do_restart', False): - time.sleep(0.5) - demo.close() - time.sleep(0.5) - break + wait_on_server(demo) sd_samplers.set_samplers() @@ -142,6 +154,9 @@ def webui(launch_api=False): if __name__ == "__main__": + if not cmd_opts.nowebui: + api_only() + if cmd_opts.api: webui(True) else: From 8d5d863a9d11850464fdb6b64f34602803c15ccc Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Tue, 18 Oct 2022 06:51:53 +0000 Subject: [PATCH 15/23] gradio and FastAPI --- modules/api/api.py | 13 ++++++++----- webui.py | 18 ++++++++---------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 14613d8c..ce98cb8c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel): class Api: - def __init__(self, app): + def __init__(self, app, queue_lock): self.router = APIRouter() - app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) + self.app = app + self.queue_lock = queue_lock + self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): populate = txt2imgreq.copy(update={ # Override __init__ params @@ -30,7 +32,8 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param - processed = process_images(p) + with self.queue_lock: + processed = process_images(p) b64images = [] for i in processed.images: @@ -52,5 +55,5 @@ class Api: raise NotImplementedError def launch(self, server_name, port): - app.include_router(self.router) - uvicorn.run(app, host=server_name, port=port) + self.app.include_router(self.router) + uvicorn.run(self.app, host=server_name, port=port) diff --git a/webui.py b/webui.py index 6212be01..71724c3b 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import time import importlib import signal import threading - +from fastapi import FastAPI from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path @@ -31,7 +31,6 @@ from modules.paths import script_path from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork - queue_lock = threading.Lock() @@ -97,7 +96,7 @@ def initialize(): def create_api(app): from modules.api.api import Api - api = Api(app) + api = Api(app, queue_lock) return api def wait_on_server(demo=None): @@ -141,7 +140,7 @@ def webui(launch_api=False): create_api(app) wait_on_server(demo) - + sd_samplers.set_samplers() print('Reloading Custom Scripts') @@ -153,11 +152,10 @@ def webui(launch_api=False): print('Restarting Gradio') -if __name__ == "__main__": - if not cmd_opts.nowebui: - api_only() - if cmd_opts.api: - webui(True) +task = [] +if __name__ == "__main__": + if cmd_opts.nowebui: + api_only() else: - webui(False) + webui(cmd_opts.api) \ No newline at end of file From e7f4808505f7a6339927c32b9a0c01bc9134bdeb Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Tue, 18 Oct 2022 19:04:56 +0000 Subject: [PATCH 16/23] provide sampler by name --- modules/api/api.py | 12 ++++++++++-- modules/api/processing.py | 16 ++++++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index ce98cb8c..ff9df0d1 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,14 +1,17 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images +from modules.sd_samplers import samplers_k_diffusion import modules.shared as shared import uvicorn -from fastapi import Body, APIRouter +from fastapi import Body, APIRouter, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel, Field, Json import json import io import base64 +sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None) + class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") parameters: Json @@ -23,9 +26,14 @@ class Api: self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + sampler_index = sampler_to_index(txt2imgreq.sampler_index) + + if sampler_index is None: + raise HTTPException(status_code=404, detail="Sampler not found") + populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": 0, + "sampler_index": sampler_index[0], "do_not_save_samples": True, "do_not_save_grid": True } diff --git a/modules/api/processing.py b/modules/api/processing.py index b6798241..2e6483ee 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -42,7 +42,8 @@ class PydanticModelGenerator: def __init__( self, model_name: str = None, - class_instance = None + class_instance = None, + additional_fields = None, ): def field_type_generator(k, v): # field_type = str if not overrides.get(k) else overrides[k]["type"] @@ -70,6 +71,13 @@ class PydanticModelGenerator: ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] + + for fields in additional_fields: + self._model_def.append(ModelDef( + field=underscore(fields["key"]), + field_alias=fields["key"], + field_type=fields["type"], + field_value=fields["default"])) def generate_model(self): """ @@ -84,4 +92,8 @@ class PydanticModelGenerator: DynamicModel.__config__.allow_mutation = True return DynamicModel -StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model() +StableDiffusionProcessingAPI = PydanticModelGenerator( + "StableDiffusionProcessingTxt2Img", + StableDiffusionProcessingTxt2Img, + [{"key": "sampler_index", "type": str, "default": "k_euler_a"}] +).generate_model() From 0f0d6ab8e06898ce066251fc769fe14e77e98ced Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Wed, 19 Oct 2022 05:19:01 +0000 Subject: [PATCH 17/23] call sampler by name --- modules/api/api.py | 11 ++++++----- modules/api/processing.py | 6 +++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index ff9df0d1..5b0c934e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,6 +1,7 @@ from modules.api.processing import StableDiffusionProcessingAPI from modules.processing import StableDiffusionProcessingTxt2Img, process_images -from modules.sd_samplers import samplers_k_diffusion +from modules.sd_samplers import all_samplers +from modules.extras import run_pnginfo import modules.shared as shared import uvicorn from fastapi import Body, APIRouter, HTTPException @@ -10,7 +11,7 @@ import json import io import base64 -sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None) +sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) class TextToImageResponse(BaseModel): images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") @@ -53,13 +54,13 @@ class Api: - def img2imgendoint(self): + def img2imgapi(self): raise NotImplementedError - def extrasendoint(self): + def extrasapi(self): raise NotImplementedError - def pnginfoendoint(self): + def pnginfoapi(self): raise NotImplementedError def launch(self, server_name, port): diff --git a/modules/api/processing.py b/modules/api/processing.py index 2e6483ee..4c541241 100644 --- a/modules/api/processing.py +++ b/modules/api/processing.py @@ -1,7 +1,7 @@ from inflection import underscore from typing import Any, Dict, Optional from pydantic import BaseModel, Field, create_model -from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.processing import StableDiffusionProcessingTxt2Img import inspect @@ -95,5 +95,5 @@ class PydanticModelGenerator: StableDiffusionProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "k_euler_a"}] -).generate_model() + [{"key": "sampler_index", "type": str, "default": "Euler"}] +).generate_model() \ No newline at end of file From 10aca1ca3e81e69e08f556a500c3dc603451429b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 19 Oct 2022 08:42:22 +0300 Subject: [PATCH 18/23] more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names) --- modules/sd_models.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..7ad6d474 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -122,11 +122,33 @@ def select_checkpoint(): return checkpoint_info +chckpoint_dict_replacements = { + 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', + 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', + 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', +} + + +def transform_checkpoint_dict_key(k): + for text, replacement in chckpoint_dict_replacements.items(): + if k.startswith(text): + k = replacement + k[len(text):] + + return k + + def get_state_dict_from_checkpoint(pl_sd): if "state_dict" in pl_sd: - return pl_sd["state_dict"] + pl_sd = pl_sd["state_dict"] - return pl_sd + sd = {} + for k, v in pl_sd.items(): + new_key = transform_checkpoint_dict_key(k) + + if new_key is not None: + sd[new_key] = v + + return sd def load_model_weights(model, checkpoint_info): @@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info): print(f"Global Step: {pl_sd['global_step']}") sd = get_state_dict_from_checkpoint(pl_sd) - model.load_state_dict(sd, strict=False) + missing, extra = model.load_state_dict(sd, strict=False) if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) From da72becb13e4b750fbcb3d158c3f843311ef9938 Mon Sep 17 00:00:00 2001 From: Silent <16026653+s-ilent@users.noreply.github.com> Date: Wed, 19 Oct 2022 16:14:33 +1030 Subject: [PATCH 19/23] Use training width/height when training hypernetworks. --- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/ui.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4905710e..b8695fc1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -196,7 +196,7 @@ def stack_conds(conds): return torch.stack(conds) -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert hypernetwork_name, 'hypernetwork not selected' path = shared.hypernetworks.get(hypernetwork_name, None) @@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) diff --git a/modules/ui.py b/modules/ui.py index fb6eb5a0..ca46343f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1341,6 +1341,8 @@ def create_ui(wrap_gradio_gpu_call): batch_size, dataset_directory, log_directory, + training_width, + training_height, steps, create_image_every, save_embedding_every, From 2fd7935ef4ed296db5dfd8c7fea99244816f8cf0 Mon Sep 17 00:00:00 2001 From: Cheka Date: Tue, 18 Oct 2022 20:28:28 -0300 Subject: [PATCH 20/23] Remove wrong self reference in CUDA support for invokeai --- modules/sd_hijack_optimizations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index a3345bb9..98123fbf 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) def einsum_op(q, k, v): if q.device.type == 'cuda': From bcfbb33e50a48b237d8d961cc2be038db53774d5 Mon Sep 17 00:00:00 2001 From: Anastasius Date: Mon, 17 Oct 2022 13:35:20 -0700 Subject: [PATCH 21/23] Added time left estimation --- modules/ui.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index ca46343f..9a54aa16 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,6 +261,15 @@ def wrap_gradio_call(func, extra_outputs=None): return f +def calc_time_left(progress): + if progress == 0: + return "N/A" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + return time.strftime('%H:%M:%S', time.gmtime(eta-time_since_start)) + + def check_progress_call(id_part): if shared.state.job_count == 0: return "", gr_show(False), gr_show(False), gr_show(False) @@ -272,11 +281,13 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + time_left = calc_time_left( progress ) + progress = min(progress, 1) progressbar = "" if opts.show_progressbar: - progressbar = f"""
{str(int(progress*100))+"%" if progress > 0.01 else ""}
""" + progressbar = f"""
{str(int(progress*100))+"% ETA:"+time_left if progress > 0.01 else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) @@ -308,6 +319,7 @@ def check_progress_call_initial(id_part): shared.state.current_latent = None shared.state.current_image = None shared.state.textinfo = None + shared.state.time_start = time.time() return check_progress_call(id_part) From 442dbedc159bb7e9cf94f0c3626f8a409e0a50eb Mon Sep 17 00:00:00 2001 From: Anastasius Date: Tue, 18 Oct 2022 10:38:07 -0700 Subject: [PATCH 22/23] Estimated time displayed if jobs take more 60 sec --- modules/ui.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 9a54aa16..fa54110b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,13 +261,17 @@ def wrap_gradio_call(func, extra_outputs=None): return f -def calc_time_left(progress): +def calc_time_left(progress, threshold, label, force_display): if progress == 0: - return "N/A" + return "" else: time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) - return time.strftime('%H:%M:%S', time.gmtime(eta-time_since_start)) + eta_relative = eta-time_since_start + if eta_relative > threshold or force_display: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + else: + return "" def check_progress_call(id_part): @@ -281,13 +285,15 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - time_left = calc_time_left( progress ) + time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display ) + if time_left != "": + shared.state.time_left_force_display = True progress = min(progress, 1) progressbar = "" if opts.show_progressbar: - progressbar = f"""
{str(int(progress*100))+"% ETA:"+time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) @@ -320,6 +326,7 @@ def check_progress_call_initial(id_part): shared.state.current_image = None shared.state.textinfo = None shared.state.time_start = time.time() + shared.state.time_left_force_display = False return check_progress_call(id_part) From 1d4aa376e6111e90888a30ae24d2bcd7f978ec51 Mon Sep 17 00:00:00 2001 From: Anastasius Date: Tue, 18 Oct 2022 12:42:39 -0700 Subject: [PATCH 23/23] Predictable long operation check for time estimation --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index fa54110b..38ba1138 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -268,7 +268,7 @@ def calc_time_left(progress, threshold, label, force_display): time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if eta_relative > threshold or force_display: + if (eta_relative > threshold and progress > 0.02) or force_display: return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) else: return ""