From af3f6489d3b229da4e688eaf439adb5d3e4f070b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 16:57:19 +0300 Subject: [PATCH 01/12] possibly defeat losing of focus for prompt when generating images with gallery open --- javascript/progressbar.js | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index c7d0343f..7a05726e 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -72,11 +72,17 @@ function check_gallery(id_gallery){ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { - //automatically re-open previously selected index (if exists) - activeElement = document.activeElement; + // automatically re-open previously selected index (if exists) + activeElement = gradioApp().activeElement; + galleryButtons[prevSelectedIndex].click(); showGalleryImage(); - if(activeElement) activeElement.focus() + + if(activeElement){ + // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it + // if somenoe has a better solution please by all means + setTimeout(function() { activeElement.focus() }, 1); + } } }) galleryObservers[id_gallery].observe( gallery, { childList:true, subtree:false }) From 695377a8b9f7de28b880d96487a9ddf7230cff14 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 19:56:23 +0300 Subject: [PATCH 02/12] make modelmerger work with ui-config.json --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..533b1db3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1767,6 +1767,7 @@ Requested path was: {f} visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): with open(ui_config_file, "w", encoding="utf8") as file: From cf47d13c1e11fcb7169bac7488d2c39e579ee491 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 21:15:32 +0300 Subject: [PATCH 03/12] localization support --- javascript/localization.js | 146 ++++++++++++++++++ localizations/Put localization files here.txt | 0 modules/localization.py | 31 ++++ modules/shared.py | 7 +- modules/ui.py | 33 ++-- script.js | 10 +- style.css | 2 +- 7 files changed, 211 insertions(+), 18 deletions(-) create mode 100644 javascript/localization.js create mode 100644 localizations/Put localization files here.txt create mode 100644 modules/localization.py diff --git a/javascript/localization.js b/javascript/localization.js new file mode 100644 index 00000000..e6644635 --- /dev/null +++ b/javascript/localization.js @@ -0,0 +1,146 @@ + +// localization = {} -- the dict with translations is created by the backend + +ignore_ids_for_localization={ + setting_sd_hypernetwork: 'OPTION', + setting_sd_model_checkpoint: 'OPTION', + setting_realesrgan_enabled_models: 'OPTION', + modelmerger_primary_model_name: 'OPTION', + modelmerger_secondary_model_name: 'OPTION', + modelmerger_tertiary_model_name: 'OPTION', + train_embedding: 'OPTION', + train_hypernetwork: 'OPTION', + txt2img_style_index: 'OPTION', + txt2img_style2_index: 'OPTION', + img2img_style_index: 'OPTION', + img2img_style2_index: 'OPTION', + setting_random_artist_categories: 'SPAN', + setting_face_restoration_model: 'SPAN', + setting_realesrgan_enabled_models: 'SPAN', + extras_upscaler_1: 'SPAN', + extras_upscaler_2: 'SPAN', +} + +re_num = /^[\.\d]+$/ +re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u + +original_lines = {} +translated_lines = {} + +function textNodesUnder(el){ + var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); + while(n=walk.nextNode()) a.push(n); + return a; +} + +function canBeTranslated(node, text){ + if(! text) return false; + if(! node.parentElement) return false; + + parentType = node.parentElement.nodeName + if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; + + if (parentType=='OPTION' || parentType=='SPAN'){ + pnode = node + for(var level=0; level<4; level++){ + pnode = pnode.parentElement + if(! pnode) break; + + if(ignore_ids_for_localization[pnode.id] == parentType) return false; + } + } + + if(re_num.test(text)) return false; + if(re_emoji.test(text)) return false; + return true +} + +function getTranslation(text){ + if(! text) return undefined + + if(translated_lines[text] === undefined){ + original_lines[text] = 1 + } + + tl = localization[text] + if(tl !== undefined){ + translated_lines[tl] = 1 + } + + return tl +} + +function processTextNode(node){ + text = node.textContent.trim() + + if(! canBeTranslated(node, text)) return + + tl = getTranslation(text) + if(tl !== undefined){ + node.textContent = tl + } +} + +function processNode(node){ + if(node.nodeType == 3){ + processTextNode(node) + return + } + + if(node.title){ + tl = getTranslation(node.title) + if(tl !== undefined){ + node.title = tl + } + } + + if(node.placeholder){ + tl = getTranslation(node.placeholder) + if(tl !== undefined){ + node.placeholder = tl + } + } + + textNodesUnder(node).forEach(function(node){ + processTextNode(node) + }) +} + +function dumpTranslations(){ + dumped = {} + + Object.keys(original_lines).forEach(function(text){ + if(dumped[text] !== undefined) return + + dumped[text] = localization[text] || text + }) + + return dumped +} + +onUiUpdate(function(m){ + m.forEach(function(mutation){ + mutation.addedNodes.forEach(function(node){ + processNode(node) + }) + }); +}) + + +document.addEventListener("DOMContentLoaded", function() { + processNode(gradioApp()) +}) + +function download_localization() { + text = JSON.stringify(dumpTranslations(), null, 4) + + var element = document.createElement('a'); + element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); + element.setAttribute('download', "localization.json"); + element.style.display = 'none'; + document.body.appendChild(element); + + element.click(); + + document.body.removeChild(element); +} diff --git a/localizations/Put localization files here.txt b/localizations/Put localization files here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/localization.py b/modules/localization.py new file mode 100644 index 00000000..b1810cda --- /dev/null +++ b/modules/localization.py @@ -0,0 +1,31 @@ +import json +import os +import sys +import traceback + +localizations = {} + + +def list_localizations(dirname): + localizations.clear() + + for file in os.listdir(dirname): + fn, ext = os.path.splitext(file) + if ext.lower() != ".json": + continue + + localizations[fn] = os.path.join(dirname, file) + + +def localization_js(current_localization_name): + fn = localizations.get(current_localization_name, None) + data = {} + if fn is not None: + try: + with open(fn, "r", encoding="utf8") as file: + data = json.load(file) + except Exception: + print(f"Error loading localization from {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return f"var localization = {json.dumps(data)}\n" diff --git a/modules/shared.py b/modules/shared.py index c2775603..2a2b0427 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models +from modules import sd_samplers, sd_models, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -31,6 +31,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") +parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -103,7 +104,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - def reload_hypernetworks(): global hypernetworks @@ -151,6 +151,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] +localization.list_localizations(cmd_opts.localizations_dir) + def realesrgan_models_names(): import modules.realesrgan_model @@ -296,6 +298,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), + 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { diff --git a/modules/ui.py b/modules/ui.py index 533b1db3..656bab7a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,7 +23,7 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models +from modules import sd_hijack, sd_models, localization from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: @@ -1056,10 +1056,10 @@ def create_ui(wrap_gradio_gpu_call): upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", hoices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): @@ -1224,10 +1224,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") with gr.Row(): - train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") with gr.Row(): - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") batch_size = gr.Number(label='Batch size', value=1, precision=0) @@ -1376,16 +1376,18 @@ def create_ui(wrap_gradio_gpu_call): else: raise Exception(f'bad options item type: {str(t)} for key {key}') + elem_id = "setting_"+key + if info.refresh is not None: if is_quicksettings: - res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - res = comp(label=info.label, value=fun, **(args or {})) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) return res @@ -1509,6 +1511,9 @@ Requested path was: {f} with gr.Row(): request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + + with gr.Row(): reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') @@ -1519,6 +1524,13 @@ Requested path was: {f} _js='function(){}' ) + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1784,6 +1796,7 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" +javascript += f"\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): diff --git a/script.js b/script.js index 88f2c839..8b3b67e3 100644 --- a/script.js +++ b/script.js @@ -21,20 +21,20 @@ function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } -function runCallback(x){ +function runCallback(x, m){ try { - x() + x(m) } catch (e) { (console.error || console.log).call(console, e.message, e); } } -function executeCallbacks(queue) { - queue.forEach(runCallback) +function executeCallbacks(queue, m) { + queue.forEach(function(x){runCallback(x, m)}) } document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ - executeCallbacks(uiUpdateCallbacks); + executeCallbacks(uiUpdateCallbacks, m); const newTab = get_uiCurrentTab(); if ( newTab && ( newTab !== uiCurrentTab ) ) { uiCurrentTab = newTab; diff --git a/style.css b/style.css index 71eb4d20..9dc4b696 100644 --- a/style.css +++ b/style.css @@ -478,7 +478,7 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name{ +#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ max-width: 2.5em; min-width: 2.5em; height: 2.4em; From ab3f997c0c4a1423a82623ae1d4d3c66005bb8da Mon Sep 17 00:00:00 2001 From: Jordan Hall Date: Mon, 17 Oct 2022 20:59:44 +0100 Subject: [PATCH 04/12] Fix typo in 'choices' when loading upscaler 2 config --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 656bab7a..e4ead347 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", hoices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): From 43cb1ddad2af31170352394b81b9a299b151ea05 Mon Sep 17 00:00:00 2001 From: Adam Snodgrass Date: Mon, 17 Oct 2022 05:21:59 -0500 Subject: [PATCH 05/12] prevent highlighting/selecting image --- javascript/imageviewer.js | 1 + 1 file changed, 1 insertion(+) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index d4ab6984..9e380c65 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -116,6 +116,7 @@ function showGalleryImage() { e.dataset.modded = true; if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' + e.style.userSelect='none' e.addEventListener('click', function (evt) { if(!opts.js_modal_lightbox) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) From d3338bdef18b3049431a0649d55ff22aa18baa68 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:46:56 +0100 Subject: [PATCH 06/12] extras extend cache key with new upscale to options --- modules/extras.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index 8dbab240..c908b43e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -91,7 +91,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels + key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, + resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels c = cached_images.get(key) if c is None: From 7432b6f4d2c3001895fc75411a34afae1810c1a2 Mon Sep 17 00:00:00 2001 From: Mykeehu Date: Tue, 18 Oct 2022 07:15:38 +0200 Subject: [PATCH 07/12] Fix typo "celem_id" to "elem_id" --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index e4ead347..2a7f64f9 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): From 786ed499226177d71e937e0342bcb9d3b1ff260f Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 19:48:39 +0300 Subject: [PATCH 08/12] use legacy attnblock --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 984b35c4..2407a461 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 From 2043c4a231eef838bb15044f502b864b55885037 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 19:49:11 +0300 Subject: [PATCH 09/12] delete xformers attnblock --- modules/sd_hijack_optimizations.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 79405525..60da7459 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -292,15 +292,3 @@ def cross_attention_attnblock_forward(self, x): return h3 -def xformers_attnblock_forward(self, x): - try: - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_).contiguous() - k1 = self.k(h_).contiguous() - v = self.v(h_).contiguous() - out = xformers.ops.memory_efficient_attention(q1, k1, v) - out = self.proj_out(out) - return x + out - except NotImplementedError: - return cross_attention_attnblock_forward(self, x) From 84823275e896bcc1f7cb4ce098ae3c5d05e17b9a Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:18:59 +0300 Subject: [PATCH 10/12] readd xformers attnblock --- modules/sd_hijack_optimizations.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 60da7459..7ebef3f0 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -292,3 +292,18 @@ def cross_attention_attnblock_forward(self, x): return h3 +def xformers_attnblock_forward(self, x): + try: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + out = xformers.ops.memory_efficient_attention(q, k, v) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out + except NotImplementedError: + return cross_attention_attnblock_forward(self, x) From 73b5dbf72a93b64445551c74a4c0dc924986081d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:19:18 +0300 Subject: [PATCH 11/12] Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2407a461..984b35c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 From c71008c74156635558bb2e877d1628913f6f781e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Tue, 18 Oct 2022 00:02:50 +0300 Subject: [PATCH 12/12] Update sd_hijack_optimizations.py --- modules/sd_hijack_optimizations.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 7ebef3f0..a3345bb9 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -301,6 +301,9 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out)