Merge branch 'AUTOMATIC1111:master' into small-touch-up

Commit e9d7eff70a

.github/ISSUE_TEMPLATE/feature_request.yml
@@ -1,7 +1,7 @@
name: Feature request
description: Suggest an idea for this project
title: "[Feature Request]: "
labels: ["suggestion"]
labels: ["enhancement"]

body:
- type: checkboxes

README.md
@@ -1,9 +1,7 @@
# Stable Diffusion web UI
A browser interface based on Gradio library for Stable Diffusion.

![](txt2img_Screenshot.png)

Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
![](screenshot.png)

## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
@@ -97,9 +95,8 @@ Alternatively, use online services (like Google Colab):
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
2. Install [git](https://git-scm.com/download/win).
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it).
5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.

### Automatic Installation on Linux
1. Install the dependencies:
@@ -141,6 +138,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention)
- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot

@@ -184,7 +184,7 @@ SOFTWARE.
</pre>

<h2><a href="https://github.com/JingyunLiang/SwinIR/blob/main/LICENSE">SwinIR</a></h2>
<small>Code added by contirubtors, most likely copied from this repository.</small>
<small>Code added by contributors, most likely copied from this repository.</small>

<pre>
Apache License
@@ -390,3 +390,30 @@ SOFTWARE.
limitations under the License.
</pre>

<h2><a href="https://github.com/AminRezaei0x443/memory-efficient-attention/blob/main/LICENSE">Memory Efficient Attention</a></h2>
<small>The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that.</small>
<pre>
MIT License

Copyright (c) 2023 Alex Birch
Copyright (c) 2023 Amin Rezaei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
</pre>

@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from secrets import compare_digest
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images
|
||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
|
||||
from modules.api.models import *
|
||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||
from modules.extras import run_extras
|
||||
@ -28,8 +28,13 @@ def upscaler_to_index(name: str):
|
||||
try:
|
||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
||||
|
||||
def script_name_to_index(name, scripts):
|
||||
try:
|
||||
return [script.title().lower() for script in scripts].index(name.lower())
|
||||
except:
|
||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
||||
|
||||
def validate_sampler_name(name):
|
||||
config = sd_samplers.all_samplers_map.get(name, None)
|
||||
@ -143,7 +148,21 @@ class Api:
|
||||
|
||||
raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"})
|
||||
|
||||
def get_script(self, script_name, script_runner):
|
||||
if script_name is None:
|
||||
return None, None
|
||||
|
||||
if not script_runner.scripts:
|
||||
script_runner.initialize_scripts(False)
|
||||
ui.create_ui()
|
||||
|
||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||
script = script_runner.selectable_scripts[script_idx]
|
||||
return script, script_idx
|
||||
|
||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||
script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img)
|
||||
|
||||
populate = txt2imgreq.copy(update={ # Override __init__ params
|
||||
"sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
|
||||
"do_not_save_samples": True,
|
||||
@ -153,14 +172,22 @@ class Api:
|
||||
if populate.sampler_name:
|
||||
populate.sampler_index = None # prevent a warning later on
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('script_name', None)
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate))
|
||||
p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
if script is not None:
|
||||
p.outpath_grids = opts.outdir_txt2img_grids
|
||||
p.outpath_samples = opts.outdir_txt2img_samples
|
||||
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
|
||||
processed = scripts.scripts_txt2img.run(p, *p.script_args)
|
||||
else:
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
|
||||
|
||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||
@ -170,6 +197,8 @@ class Api:
|
||||
if init_images is None:
|
||||
raise HTTPException(status_code=404, detail="Init image not found")
|
||||
|
||||
script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img)
|
||||
|
||||
mask = img2imgreq.mask
|
||||
if mask:
|
||||
mask = decode_base64_to_image(mask)
|
||||
@ -186,13 +215,20 @@ class Api:
|
||||
|
||||
args = vars(populate)
|
||||
args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine.
|
||||
args.pop('script_name', None)
|
||||
|
||||
with self.queue_lock:
|
||||
p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
|
||||
p.init_images = [decode_base64_to_image(x) for x in init_images]
|
||||
|
||||
shared.state.begin()
|
||||
processed = process_images(p)
|
||||
if script is not None:
|
||||
p.outpath_grids = opts.outdir_img2img_grids
|
||||
p.outpath_samples = opts.outdir_img2img_samples
|
||||
p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args
|
||||
processed = scripts.scripts_img2img.run(p, *p.script_args)
|
||||
else:
|
||||
processed = process_images(p)
|
||||
shared.state.end()
|
||||
|
||||
b64images = list(map(encode_pil_to_base64, processed.images))
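The hunks above thread the new optional `script_name` and `script_args` request fields through both endpoints: `get_script` resolves the name against the runner's selectable scripts, and the args are prepended with `[script_idx + 1] + [None] * (script.args_from - 1)` so they land in the positions the script runner expects. A minimal client-side sketch of the feature (host, port, and payload values are illustrative assumptions mirroring the API test added at the end of this commit, not part of this diff):

```python
import requests

# Assumes a local instance launched with the --api flag; adjust host/port as needed.
URL = "http://127.0.0.1:7860/sdapi/v1/img2img"

payload = {
    "init_images": ["<base64-encoded PNG>"],  # placeholder, not a real image
    "prompt": "a castle on a hill",
    "script_name": "sd upscale",              # matched case-insensitively against script titles
    "script_args": ["", 8, "Lanczos", 2.0],   # positional args for the selected script
}

resp = requests.post(URL, json=payload)
resp.raise_for_status()
print(f"received {len(resp.json()['images'])} image(s)")
```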
|
||||
|
@ -100,13 +100,13 @@ class PydanticModelGenerator:
|
||||
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingTxt2Img",
|
||||
StableDiffusionProcessingTxt2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}]
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
|
||||
).generate_model()
|
||||
|
||||
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
|
||||
"StableDiffusionProcessingImg2Img",
|
||||
StableDiffusionProcessingImg2Img,
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
|
||||
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}]
|
||||
).generate_model()
|
||||
|
||||
class TextToImageResponse(BaseModel):
|
||||
@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel):
|
||||
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
|
||||
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
|
||||
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.")
|
||||
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
|
||||
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
|
||||
|
@ -98,7 +98,7 @@ class StableDiffusionProcessing():
|
||||
"""
|
||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||
"""
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None):
|
||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||
if sampler_index is not None:
|
||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||
|
||||
@ -149,7 +149,7 @@ class StableDiffusionProcessing():
|
||||
self.seed_resize_from_w = 0
|
||||
|
||||
self.scripts = None
|
||||
self.script_args = None
|
||||
self.script_args = script_args
|
||||
self.all_prompts = None
|
||||
self.all_negative_prompts = None
|
||||
self.all_seeds = None
|
||||
|
@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
|
||||
from modules.shared import cmd_opts
|
||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||
|
||||
from modules.sd_hijack_optimizations import invokeAI_mps_available
|
||||
|
||||
import ldm.modules.attention
|
||||
import ldm.modules.diffusionmodules.model
|
||||
import ldm.modules.diffusionmodules.openaimodel
|
||||
@ -43,20 +41,19 @@ def apply_optimizations():
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
|
||||
optimization_method = 'xformers'
|
||||
elif cmd_opts.opt_sub_quad_attention:
|
||||
print("Applying sub-quadratic cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|
||||
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
|
||||
optimization_method = 'sub-quadratic'
|
||||
elif cmd_opts.opt_split_attention_v1:
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
|
||||
if not invokeAI_mps_available and shared.device.type == 'mps':
|
||||
print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
|
||||
print("Applying v1 cross attention optimization.")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
|
||||
optimization_method = 'V1'
|
||||
else:
|
||||
print("Applying cross attention optimization (InvokeAI).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||
optimization_method = 'InvokeAI'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (InvokeAI).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
|
||||
optimization_method = 'InvokeAI'
|
||||
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
|
||||
print("Applying cross attention optimization (Doggettx).")
|
||||
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
|
||||
@ -86,10 +83,12 @@ class StableDiffusionModelHijack:
|
||||
clip = None
|
||||
optimization_method = None
|
||||
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
|
||||
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
|
||||
|
||||
def __init__(self):
|
||||
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
|
||||
|
||||
def hijack(self, m):
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
model_embeddings = m.cond_stage_model.roberta.embeddings
|
||||
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
|
||||
@ -120,7 +119,6 @@ class StableDiffusionModelHijack:
|
||||
self.layers = flatten(m)
|
||||
|
||||
def undo_hijack(self, m):
|
||||
|
||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||
|
||||
|
@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
||||
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
|
||||
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
|
||||
original_mean = z.mean()
|
||||
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
|
||||
new_mean = z.mean()
|
||||
z *= original_mean / new_mean
|
||||
z = z * (original_mean / new_mean)
|
||||
|
||||
return z
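The two rewritten lines above only swap in-place multiplies for out-of-place ones; the arithmetic is unchanged. As a sketch, with per-token emphasis weights $w$ (`batch_multipliers`) the block computes

\[ z' = z \odot w, \qquad z'' = z' \cdot \frac{\mu(z)}{\mu(z')} \]

where $\mu(\cdot)$ is the mean over all elements of the tensor; as the comment above notes, restoring the original mean is a heuristic that happens to suppress artifacts, not an exact correction.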
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import math
|
||||
import sys
|
||||
import traceback
|
||||
import importlib
|
||||
import psutil
|
||||
|
||||
import torch
|
||||
from torch import einsum
|
||||
@ -12,6 +12,8 @@ from einops import rearrange
|
||||
from modules import shared
|
||||
from modules.hypernetworks import hypernetwork
|
||||
|
||||
from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
|
||||
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||
try:
|
||||
@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
|
||||
def get_available_vram():
|
||||
if shared.device.type == 'cuda':
|
||||
stats = torch.cuda.memory_stats(shared.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
else:
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
|
||||
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
|
||||
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
|
||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||
@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
||||
return self.to_out(r2)
|
||||
|
||||
|
||||
def check_for_psutil():
|
||||
try:
|
||||
spec = importlib.util.find_spec('psutil')
|
||||
return spec is not None
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
invokeAI_mps_available = check_for_psutil()
|
||||
|
||||
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
|
||||
if invokeAI_mps_available:
|
||||
import psutil
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_compvis(q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
@ -215,6 +214,71 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
||||
|
||||
# -- End of code from https://github.com/invoke-ai/InvokeAI --
|
||||
|
||||
|
||||
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
|
||||
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
|
||||
def sub_quad_attention_forward(self, x, context=None, mask=None):
|
||||
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
|
||||
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
|
||||
k = self.to_k(context_k)
|
||||
v = self.to_v(context_v)
|
||||
del context, context_k, context_v, x
|
||||
|
||||
q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||
|
||||
x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
|
||||
x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
|
||||
|
||||
out_proj, dropout = self.to_out
|
||||
x = out_proj(x)
|
||||
x = dropout(x)
|
||||
|
||||
return x
|
||||
|
||||
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
|
||||
bytes_per_token = torch.finfo(q.dtype).bits//8
|
||||
batch_x_heads, q_tokens, _ = q.shape
|
||||
_, k_tokens, _ = k.shape
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
if chunk_threshold is None:
|
||||
chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
|
||||
elif chunk_threshold == 0:
|
||||
chunk_threshold_bytes = None
|
||||
else:
|
||||
chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())
|
||||
|
||||
if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
|
||||
kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
|
||||
elif kv_chunk_size_min == 0:
|
||||
kv_chunk_size_min = None
|
||||
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
|
||||
return efficient_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_chunk_size=q_chunk_size,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
kv_chunk_size_min = kv_chunk_size_min,
|
||||
use_checkpoint=use_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def xformers_attention_forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q_in = self.to_q(x)
|
||||
@ -252,12 +316,7 @@ def cross_attention_attnblock_forward(self, x):
|
||||
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
mem_free_total = get_available_vram()
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
mem_required = tensor_size * 2.5
|
||||
@ -312,3 +371,19 @@ def xformers_attnblock_forward(self, x):
|
||||
return x + out
|
||||
except NotImplementedError:
|
||||
return cross_attention_attnblock_forward(self, x)
|
||||
|
||||
def sub_quad_attnblock_forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
b, c, h, w = q.shape
|
||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
|
||||
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
|
||||
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
|
||||
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
|
||||
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
|
||||
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
|
||||
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
|
||||
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
|
||||
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
|
||||
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
|
||||
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
|
||||
|
modules/sub_quadratic_attention.py (new file, 205 lines)
@ -0,0 +1,205 @@
|
||||
# original source:
|
||||
# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py
|
||||
# license:
|
||||
# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)
|
||||
# credit:
|
||||
# Amin Rezaei (original author)
|
||||
# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
|
||||
# brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
|
||||
# implementation of:
|
||||
# Self-attention Does Not Need O(n2) Memory":
|
||||
# https://arxiv.org/abs/2112.05682v2
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
import math
|
||||
from typing import Optional, NamedTuple, Protocol, List
|
||||
|
||||
def narrow_trunc(
|
||||
input: Tensor,
|
||||
dim: int,
|
||||
start: int,
|
||||
length: int
|
||||
) -> Tensor:
|
||||
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
|
||||
|
||||
class AttnChunk(NamedTuple):
|
||||
exp_values: Tensor
|
||||
exp_weights_sum: Tensor
|
||||
max_score: Tensor
|
||||
|
||||
class SummarizeChunk(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
) -> AttnChunk: ...
|
||||
|
||||
class ComputeQueryChunkAttn(Protocol):
|
||||
@staticmethod
|
||||
def __call__(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
) -> Tensor: ...
|
||||
|
||||
def _summarize_chunk(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
) -> AttnChunk:
|
||||
attn_weights = torch.baddbmm(
|
||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key.transpose(1,2),
|
||||
alpha=scale,
|
||||
beta=0,
|
||||
)
|
||||
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
|
||||
max_score = max_score.detach()
|
||||
exp_weights = torch.exp(attn_weights - max_score)
|
||||
exp_values = torch.bmm(exp_weights, value)
|
||||
max_score = max_score.squeeze(-1)
|
||||
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
|
||||
|
||||
def _query_chunk_attention(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
summarize_chunk: SummarizeChunk,
|
||||
kv_chunk_size: int,
|
||||
) -> Tensor:
|
||||
batch_x_heads, k_tokens, k_channels_per_head = key.shape
|
||||
_, _, v_channels_per_head = value.shape
|
||||
|
||||
def chunk_scanner(chunk_idx: int) -> AttnChunk:
|
||||
key_chunk = narrow_trunc(
|
||||
key,
|
||||
1,
|
||||
chunk_idx,
|
||||
kv_chunk_size
|
||||
)
|
||||
value_chunk = narrow_trunc(
|
||||
value,
|
||||
1,
|
||||
chunk_idx,
|
||||
kv_chunk_size
|
||||
)
|
||||
return summarize_chunk(query, key_chunk, value_chunk)
|
||||
|
||||
chunks: List[AttnChunk] = [
|
||||
chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size)
|
||||
]
|
||||
acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
|
||||
chunk_values, chunk_weights, chunk_max = acc_chunk
|
||||
|
||||
global_max, _ = torch.max(chunk_max, 0, keepdim=True)
|
||||
max_diffs = torch.exp(chunk_max - global_max)
|
||||
chunk_values *= torch.unsqueeze(max_diffs, -1)
|
||||
chunk_weights *= max_diffs
|
||||
|
||||
all_values = chunk_values.sum(dim=0)
|
||||
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
|
||||
return all_values / all_weights
|
||||
|
||||
# TODO: refactor CrossAttention#get_attention_scores to share code with this
|
||||
def _get_attention_scores_no_kv_chunking(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
scale: float,
|
||||
) -> Tensor:
|
||||
attn_scores = torch.baddbmm(
|
||||
torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
|
||||
query,
|
||||
key.transpose(1,2),
|
||||
alpha=scale,
|
||||
beta=0,
|
||||
)
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
hidden_states_slice = torch.bmm(attn_probs, value)
|
||||
return hidden_states_slice
|
||||
|
||||
class ScannedChunk(NamedTuple):
|
||||
chunk_idx: int
|
||||
attn_chunk: AttnChunk
|
||||
|
||||
def efficient_dot_product_attention(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
query_chunk_size=1024,
|
||||
kv_chunk_size: Optional[int] = None,
|
||||
kv_chunk_size_min: Optional[int] = None,
|
||||
use_checkpoint=True,
|
||||
):
|
||||
"""Computes efficient dot-product attention given query, key, and value.
|
||||
This is efficient version of attention presented in
|
||||
https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements.
|
||||
Args:
|
||||
query: queries for calculating attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
key: keys for calculating attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
value: values to be used in attention with shape of
|
||||
`[batch * num_heads, tokens, channels_per_head]`.
|
||||
query_chunk_size: int: query chunks size
|
||||
kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens)
|
||||
kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done).
|
||||
use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference)
|
||||
Returns:
|
||||
Output of shape `[batch * num_heads, query_tokens, channels_per_head]`.
|
||||
"""
|
||||
batch_x_heads, q_tokens, q_channels_per_head = query.shape
|
||||
_, k_tokens, _ = key.shape
|
||||
scale = q_channels_per_head ** -0.5
|
||||
|
||||
kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens)
|
||||
if kv_chunk_size_min is not None:
|
||||
kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)
|
||||
|
||||
def get_query_chunk(chunk_idx: int) -> Tensor:
|
||||
return narrow_trunc(
|
||||
query,
|
||||
1,
|
||||
chunk_idx,
|
||||
min(query_chunk_size, q_tokens)
|
||||
)
|
||||
|
||||
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
|
||||
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
|
||||
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
|
||||
_get_attention_scores_no_kv_chunking,
|
||||
scale=scale
|
||||
) if k_tokens <= kv_chunk_size else (
|
||||
# fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw)
|
||||
partial(
|
||||
_query_chunk_attention,
|
||||
kv_chunk_size=kv_chunk_size,
|
||||
summarize_chunk=summarize_chunk,
|
||||
)
|
||||
)
|
||||
|
||||
if q_tokens <= query_chunk_size:
|
||||
# fast-path for when there's just 1 query chunk
|
||||
return compute_query_chunk_attn(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
# TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
|
||||
# and pass slices to be mutated, instead of torch.cat()ing the returned slices
|
||||
res = torch.cat([
|
||||
compute_query_chunk_attn(
|
||||
query=get_query_chunk(i * query_chunk_size),
|
||||
key=key,
|
||||
value=value,
|
||||
) for i in range(math.ceil(q_tokens / query_chunk_size))
|
||||
], dim=1)
|
||||
return res
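Taken as a whole, the new module exposes `efficient_dot_product_attention` as a chunked replacement for plain softmax(QKᵀ·scale)V on `[batch * heads, tokens, channels_per_head]` tensors: each query chunk is attended over key/value chunks, and the per-chunk partial softmax results are rescaled by their tracked maxima before being summed, following the O(sqrt(n))-memory scheme from the paper cited at the top of the file. A minimal sketch of calling it directly (sizes and chunk settings are arbitrary illustrative values, not taken from this diff):

```python
import torch
from modules.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, channels = 16, 4096, 4096, 40  # illustrative shapes
q = torch.randn(batch_x_heads, q_tokens, channels)
k = torch.randn(batch_x_heads, k_tokens, channels)
v = torch.randn(batch_x_heads, k_tokens, channels)

# query_chunk_size bounds how many query tokens are attended at once; kv_chunk_size
# defaults to sqrt(k_tokens). use_checkpoint=False is the inference setting.
out = efficient_dot_product_attention(q, k, v, query_chunk_size=1024, use_checkpoint=False)
assert out.shape == (batch_x_heads, q_tokens, channels)
```

In the web UI itself this is only reached through `sub_quad_attention_forward`, which is what the `--opt-sub-quad-attention` code path in `sd_hijack` installs.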
|
@ -66,17 +66,41 @@ class Embedding:
|
||||
return self.cached_checksum
|
||||
|
||||
|
||||
class DirWithTextualInversionEmbeddings:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.mtime = None
|
||||
|
||||
def has_changed(self):
|
||||
if not os.path.isdir(self.path):
|
||||
return False
|
||||
|
||||
mt = os.path.getmtime(self.path)
|
||||
if self.mtime is None or mt > self.mtime:
|
||||
return True
|
||||
|
||||
def update(self):
|
||||
if not os.path.isdir(self.path):
|
||||
return
|
||||
|
||||
self.mtime = os.path.getmtime(self.path)
|
||||
|
||||
|
||||
class EmbeddingDatabase:
|
||||
def __init__(self, embeddings_dir):
|
||||
def __init__(self):
|
||||
self.ids_lookup = {}
|
||||
self.word_embeddings = {}
|
||||
self.skipped_embeddings = {}
|
||||
self.dir_mtime = None
|
||||
self.embeddings_dir = embeddings_dir
|
||||
self.expected_shape = -1
|
||||
self.embedding_dirs = {}
|
||||
|
||||
def add_embedding_dir(self, path):
|
||||
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
|
||||
|
||||
def clear_embedding_dirs(self):
|
||||
self.embedding_dirs.clear()
|
||||
|
||||
def register_embedding(self, embedding, model):
|
||||
|
||||
self.word_embeddings[embedding.name] = embedding
|
||||
|
||||
ids = model.cond_stage_model.tokenize([embedding.name])[0]
|
||||
@ -93,65 +117,62 @@ class EmbeddingDatabase:
|
||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
return vec.shape[1]
|
||||
|
||||
def load_textual_inversion_embeddings(self, force_reload = False):
|
||||
mt = os.path.getmtime(self.embeddings_dir)
|
||||
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
|
||||
return
|
||||
def load_from_file(self, path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
self.dir_mtime = mt
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
self.skipped_embeddings.clear()
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
def process_file(path, filename):
|
||||
name, ext = os.path.splitext(filename)
|
||||
ext = ext.upper()
|
||||
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
|
||||
_, second_ext = os.path.splitext(name)
|
||||
if second_ext.upper() == '.PREVIEW':
|
||||
return
|
||||
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
embed_image = Image.open(path)
|
||||
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
|
||||
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
|
||||
name = data.get('name', name)
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
data = extract_image_data_embed(embed_image)
|
||||
name = data.get('name', name)
|
||||
elif ext in ['.BIN', '.PT']:
|
||||
data = torch.load(path, map_location="cpu")
|
||||
else:
|
||||
return
|
||||
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
# textual inversion embeddings
|
||||
if 'string_to_param' in data:
|
||||
param_dict = data['string_to_param']
|
||||
if hasattr(param_dict, '_parameters'):
|
||||
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
||||
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
||||
emb = next(iter(param_dict.items()))[1]
|
||||
# diffuser concepts
|
||||
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
||||
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
else:
|
||||
self.skipped_embeddings[name] = embedding
|
||||
emb = next(iter(data.values()))
|
||||
if len(emb.shape) == 1:
|
||||
emb = emb.unsqueeze(0)
|
||||
else:
|
||||
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
||||
|
||||
for root, dirs, fns in os.walk(self.embeddings_dir):
|
||||
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
||||
embedding = Embedding(vec, name)
|
||||
embedding.step = data.get('step', None)
|
||||
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
|
||||
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
||||
embedding.vectors = vec.shape[0]
|
||||
embedding.shape = vec.shape[-1]
|
||||
|
||||
if self.expected_shape == -1 or self.expected_shape == embedding.shape:
|
||||
self.register_embedding(embedding, shared.sd_model)
|
||||
else:
|
||||
self.skipped_embeddings[name] = embedding
|
||||
|
||||
def load_from_dir(self, embdir):
|
||||
if not os.path.isdir(embdir.path):
|
||||
return
|
||||
|
||||
for root, dirs, fns in os.walk(embdir.path):
|
||||
for fn in fns:
|
||||
try:
|
||||
fullfn = os.path.join(root, fn)
|
||||
@ -159,12 +180,32 @@ class EmbeddingDatabase:
|
||||
if os.stat(fullfn).st_size == 0:
|
||||
continue
|
||||
|
||||
process_file(fullfn, fn)
|
||||
self.load_from_file(fullfn, fn)
|
||||
except Exception:
|
||||
print(f"Error loading embedding {fn}:", file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
continue
|
||||
|
||||
def load_textual_inversion_embeddings(self, force_reload=False):
|
||||
if not force_reload:
|
||||
need_reload = False
|
||||
for path, embdir in self.embedding_dirs.items():
|
||||
if embdir.has_changed():
|
||||
need_reload = True
|
||||
break
|
||||
|
||||
if not need_reload:
|
||||
return
|
||||
|
||||
self.ids_lookup.clear()
|
||||
self.word_embeddings.clear()
|
||||
self.skipped_embeddings.clear()
|
||||
self.expected_shape = self.get_expected_shape()
|
||||
|
||||
for path, embdir in self.embedding_dirs.items():
|
||||
self.load_from_dir(embdir)
|
||||
embdir.update()
|
||||
|
||||
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
|
||||
if len(self.skipped_embeddings) > 0:
|
||||
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
|
||||
@ -247,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
|
||||
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
|
||||
assert steps, "Max steps is empty or 0"
|
||||
assert isinstance(steps, int), "Max steps must be integer"
|
||||
assert steps > 0 , "Max steps must be positive"
|
||||
assert steps > 0, "Max steps must be positive"
|
||||
assert isinstance(save_model_every, int), "Save {name} must be integer"
|
||||
assert save_model_every >= 0 , "Save {name} must be positive or 0"
|
||||
assert save_model_every >= 0, "Save {name} must be positive or 0"
|
||||
assert isinstance(create_image_every, int), "Create image must be integer"
|
||||
assert create_image_every >= 0 , "Create image must be positive or 0"
|
||||
assert create_image_every >= 0, "Create image must be positive or 0"
|
||||
if save_model_every or create_image_every:
|
||||
assert log_directory, "Log directory is empty"
|
||||
|
||||
|
||||
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
|
||||
save_embedding_every = save_embedding_every or 0
|
||||
create_image_every = create_image_every or 0
|
||||
|
@@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
with devices.autocast():
p.init([""], [0], [0])

return f"resize to: <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{p.hr_upscale_to_x}x{p.hr_upscale_to_y}</span>"


def apply_styles(prompt, prompt_neg, style1_name, style2_name):

@@ -30,4 +30,4 @@ inflection
GitPython
torchsde
safetensors
psutil; sys_platform == 'darwin'
psutil

screenshot.png (binary file changed, not shown; before: 513 KiB, after: 411 KiB)

@@ -25,6 +25,8 @@ class Script(scripts.Script):
return [info, overlap, upscaler_index, scale_factor]

def run(self, p, _, overlap, upscaler_index, scale_factor):
if isinstance(upscaler_index, str):
upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower())
processing.fix_seed(p)
upscaler = shared.sd_upscalers[upscaler_index]

style.css
@@ -512,7 +512,7 @@ input[type="range"]{
border: none;
background: none;
flex: unset;
gap: 0.5em;
gap: 1em;
}

#quicksettings > div > div{
@@ -521,6 +521,17 @@ input[type="range"]{
padding: 0;
}

#quicksettings > div > div > div > div > label > span {
position: relative;
margin-right: 9em;
margin-bottom: -1em;
}

#quicksettings > div > div > label > span {
position: relative;
margin-bottom: -1em;
}

canvas[key="mask"] {
z-index: 12 !important;
filter: invert();

@@ -50,6 +50,12 @@ class TestImg2ImgWorking(unittest.TestCase):
self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png"))
self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)

def test_img2img_sd_upscale_performed(self):
self.simple_img2img["script_name"] = "sd upscale"
self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0]

self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)


if __name__ == "__main__":
unittest.main()

(binary image file not shown; size before: 329 KiB)