Merge branch 'dev' into tooltip
This commit is contained in:
commit
73d956454f
3
.gitignore
vendored
3
.gitignore
vendored
@ -32,4 +32,5 @@ notification.mp3
|
|||||||
/extensions
|
/extensions
|
||||||
/test/stdout.txt
|
/test/stdout.txt
|
||||||
/test/stderr.txt
|
/test/stderr.txt
|
||||||
/cache.json
|
/cache.json*
|
||||||
|
/config_states/
|
||||||
|
62
CHANGELOG.md
Normal file
62
CHANGELOG.md
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
## 1.1.1
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
|
||||||
|
|
||||||
|
## 1.1.0
|
||||||
|
### Features:
|
||||||
|
* switch to torch 2.0.0 (except for AMD GPUs)
|
||||||
|
* visual improvements to custom code scripts
|
||||||
|
* add filename patterns: [clip_skip], [hasprompt<>], [batch_number], [generation_number]
|
||||||
|
* add support for saving init images in img2img, and record their hashes in infotext for reproducability
|
||||||
|
* automatically select current word when adjusting weight with ctrl+up/down
|
||||||
|
* add dropdowns for X/Y/Z plot
|
||||||
|
* setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
|
||||||
|
* support Gradio's theme API
|
||||||
|
* use TCMalloc on Linux by default; possible fix for memory leaks
|
||||||
|
* (optimization) option to remove negative conditioning at low sigma values #9177
|
||||||
|
* embed model merge metadata in .safetensors file
|
||||||
|
* extension settings backup/restore feature #9169
|
||||||
|
* add "resize by" and "resize to" tabs to img2img
|
||||||
|
* add option "keep original size" to textual inversion images preprocess
|
||||||
|
* image viewer scrolling via analog stick
|
||||||
|
* button to restore the progress from session lost / tab reload
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* gradio bumped to 3.28.1
|
||||||
|
* in extra tab, change extras "scale to" to sliders
|
||||||
|
* add labels to tool buttons to make it possible to hide them
|
||||||
|
* add tiled inference support for ScuNET
|
||||||
|
* add branch support for extension installation
|
||||||
|
* change linux installation script to insall into current directory rather than /home/username
|
||||||
|
* sort textual inversion embeddings by name (case insensitive)
|
||||||
|
* allow styles.csv to be symlinked or mounted in docker
|
||||||
|
* remove the "do not add watermark to images" option
|
||||||
|
* make selected tab configurable with UI config
|
||||||
|
* extra networks UI in now fixed height and scrollable
|
||||||
|
* add disable_tls_verify arg for use with self-signed certs
|
||||||
|
|
||||||
|
### Extensions:
|
||||||
|
* Add reload callback
|
||||||
|
* add is_hr_pass field for processing
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix broken batch image processing on 'Extras/Batch Process' tab
|
||||||
|
* add "None" option to extra networks dropdowns
|
||||||
|
* fix FileExistsError for CLIP Interrogator
|
||||||
|
* fix /sdapi/v1/txt2img endpoint not working on Linux #9319
|
||||||
|
* fix disappearing live previews and progressbar during slow tasks
|
||||||
|
* fix fullscreen image view not working properly in some cases
|
||||||
|
* prevent alwayson_scripts args param resizing script_arg list when they are inserted in it
|
||||||
|
* fix prompt schedule for second order samplers
|
||||||
|
* fix image mask/composite for weird resolutions #9628
|
||||||
|
* use correct images for previews when using AND (see #9491)
|
||||||
|
* one broken image in img2img batch won't stop all processing
|
||||||
|
* fix image orientation bug in train/preprocess
|
||||||
|
* fix Ngrok recreating tunnels every reload
|
||||||
|
* fix --realesrgan-models-path and --ldsr-models-path not working
|
||||||
|
* fix --skip-install not working
|
||||||
|
* outpainting Mk2 & Poorman should use the SAMPLE file format to save images, not GRID file format
|
||||||
|
* do not fail all Loras if some have failed to load when making a picture
|
||||||
|
|
||||||
|
## 1.0.0
|
||||||
|
* everything
|
@ -100,7 +100,7 @@ Alternatively, use online services (like Google Colab):
|
|||||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||||
|
|
||||||
### Automatic Installation on Windows
|
### Automatic Installation on Windows
|
||||||
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
|
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
|
||||||
2. Install [git](https://git-scm.com/download/win).
|
2. Install [git](https://git-scm.com/download/win).
|
||||||
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
|
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
|
||||||
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
|
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
|
||||||
@ -115,11 +115,12 @@ sudo dnf install wget git python3
|
|||||||
# Arch-based:
|
# Arch-based:
|
||||||
sudo pacman -S wget git python3
|
sudo pacman -S wget git python3
|
||||||
```
|
```
|
||||||
2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
|
2. Navigate to the directory you would like the webui to be installed and execute the following command:
|
||||||
```bash
|
```bash
|
||||||
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
||||||
```
|
```
|
||||||
3. Run `webui.sh`.
|
3. Run `webui.sh`.
|
||||||
|
4. Check `webui-user.sh` for options.
|
||||||
### Installation on Apple Silicon
|
### Installation on Apple Silicon
|
||||||
|
|
||||||
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
|
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
|
||||||
@ -158,4 +159,4 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- Security advice - RyotaK
|
- Security advice - RyotaK
|
||||||
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
@ -4,8 +4,8 @@ channels:
|
|||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- python=3.10
|
- python=3.10
|
||||||
- pip=22.2.2
|
- pip=23.0
|
||||||
- cudatoolkit=11.3
|
- cudatoolkit=11.8
|
||||||
- pytorch=1.12.1
|
- pytorch=2.0
|
||||||
- torchvision=0.13.1
|
- torchvision=0.15
|
||||||
- numpy=1.23.1
|
- numpy=1.23
|
||||||
|
@ -25,22 +25,28 @@ class UpscalerLDSR(Upscaler):
|
|||||||
yaml_path = os.path.join(self.model_path, "project.yaml")
|
yaml_path = os.path.join(self.model_path, "project.yaml")
|
||||||
old_model_path = os.path.join(self.model_path, "model.pth")
|
old_model_path = os.path.join(self.model_path, "model.pth")
|
||||||
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
||||||
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
|
|
||||||
|
local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
|
||||||
|
local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
|
||||||
|
local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
|
||||||
|
local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
|
||||||
|
|
||||||
if os.path.exists(yaml_path):
|
if os.path.exists(yaml_path):
|
||||||
statinfo = os.stat(yaml_path)
|
statinfo = os.stat(yaml_path)
|
||||||
if statinfo.st_size >= 10485760:
|
if statinfo.st_size >= 10485760:
|
||||||
print("Removing invalid LDSR YAML file.")
|
print("Removing invalid LDSR YAML file.")
|
||||||
os.remove(yaml_path)
|
os.remove(yaml_path)
|
||||||
|
|
||||||
if os.path.exists(old_model_path):
|
if os.path.exists(old_model_path):
|
||||||
print("Renaming model from model.pth to model.ckpt")
|
print("Renaming model from model.pth to model.ckpt")
|
||||||
os.rename(old_model_path, new_model_path)
|
os.rename(old_model_path, new_model_path)
|
||||||
if os.path.exists(safetensors_model_path):
|
|
||||||
model = safetensors_model_path
|
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
|
||||||
|
model = local_safetensors_path
|
||||||
else:
|
else:
|
||||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
|
||||||
file_name="model.ckpt", progress=True)
|
|
||||||
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
|
||||||
file_name="project.yaml", progress=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return LDSR(model, yaml)
|
return LDSR(model, yaml)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from modules import extra_networks, shared
|
from modules import extra_networks, shared
|
||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
@ -8,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_lora
|
additional = shared.opts.sd_lora
|
||||||
|
|
||||||
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import re
|
|||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from modules import shared, devices, sd_models, errors
|
from modules import shared, devices, sd_models, errors, scripts
|
||||||
|
|
||||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||||
|
|
||||||
@ -93,6 +93,7 @@ class LoraOnDisk:
|
|||||||
self.metadata = m
|
self.metadata = m
|
||||||
|
|
||||||
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
||||||
|
self.alias = self.metadata.get('ss_output_name', self.name)
|
||||||
|
|
||||||
|
|
||||||
class LoraModule:
|
class LoraModule:
|
||||||
@ -165,8 +166,10 @@ def load_lora(name, filename):
|
|||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.MultiheadAttention:
|
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||||
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||||
else:
|
else:
|
||||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
continue
|
continue
|
||||||
@ -199,11 +202,11 @@ def load_loras(names, multipliers=None):
|
|||||||
|
|
||||||
loaded_loras.clear()
|
loaded_loras.clear()
|
||||||
|
|
||||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
if any([x is None for x in loras_on_disk]):
|
if any([x is None for x in loras_on_disk]):
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
|
||||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
|
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
lora = already_loaded.get(name, None)
|
lora = already_loaded.get(name, None)
|
||||||
@ -211,7 +214,11 @@ def load_loras(names, multipliers=None):
|
|||||||
lora_on_disk = loras_on_disk[i]
|
lora_on_disk = loras_on_disk[i]
|
||||||
if lora_on_disk is not None:
|
if lora_on_disk is not None:
|
||||||
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
||||||
lora = load_lora(name, lora_on_disk.filename)
|
try:
|
||||||
|
lora = load_lora(name, lora_on_disk.filename)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"loading Lora {lora_on_disk.filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
if lora is None:
|
if lora is None:
|
||||||
print(f"Couldn't find Lora with name {name}")
|
print(f"Couldn't find Lora with name {name}")
|
||||||
@ -228,6 +235,8 @@ def lora_calc_updown(lora, module, target):
|
|||||||
|
|
||||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||||
|
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||||
else:
|
else:
|
||||||
updown = up @ down
|
updown = up @ down
|
||||||
|
|
||||||
@ -339,6 +348,7 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
|||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
available_loras.clear()
|
available_loras.clear()
|
||||||
|
available_lora_aliases.clear()
|
||||||
|
|
||||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||||
|
|
||||||
@ -352,11 +362,50 @@ def list_available_loras():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
entry = LoraOnDisk(name, filename)
|
||||||
|
|
||||||
available_loras[name] = LoraOnDisk(name, filename)
|
available_loras[name] = entry
|
||||||
|
|
||||||
|
available_lora_aliases[name] = entry
|
||||||
|
available_lora_aliases[entry.alias] = entry
|
||||||
|
|
||||||
|
|
||||||
|
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||||
|
|
||||||
|
|
||||||
|
def infotext_pasted(infotext, params):
|
||||||
|
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||||
|
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||||
|
|
||||||
|
added = []
|
||||||
|
|
||||||
|
for k, v in params.items():
|
||||||
|
if not k.startswith("AddNet Model "):
|
||||||
|
continue
|
||||||
|
|
||||||
|
num = k[13:]
|
||||||
|
|
||||||
|
if params.get("AddNet Module " + num) != "LoRA":
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = params.get("AddNet Model " + num)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re_lora_name.match(name)
|
||||||
|
if m:
|
||||||
|
name = m.group(1)
|
||||||
|
|
||||||
|
multiplier = params.get("AddNet Weight A " + num, "1.0")
|
||||||
|
|
||||||
|
added.append(f"<lora:{name}:{multiplier}>")
|
||||||
|
|
||||||
|
if added:
|
||||||
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
|
|
||||||
available_loras = {}
|
available_loras = {}
|
||||||
|
available_lora_aliases = {}
|
||||||
loaded_loras = []
|
loaded_loras = []
|
||||||
|
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
@ -49,8 +49,9 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
|
|||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
script_callbacks.on_before_ui(before_ui)
|
script_callbacks.on_before_ui(before_ui)
|
||||||
|
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
}))
|
}))
|
||||||
|
@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<lora:{lora_on_disk.alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
||||||
}
|
}
|
||||||
|
@ -5,11 +5,15 @@ import traceback
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader
|
from modules import devices, modelloader
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet as net
|
||||||
|
from modules.shared import opts
|
||||||
|
from modules import images
|
||||||
|
|
||||||
|
|
||||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
@ -42,28 +46,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
scalers.append(scaler_data2)
|
scalers.append(scaler_data2)
|
||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img: PIL.Image, selected_file):
|
@staticmethod
|
||||||
|
@torch.no_grad()
|
||||||
|
def tiled_inference(img, model):
|
||||||
|
# test the image tile by tile
|
||||||
|
h, w = img.shape[2:]
|
||||||
|
tile = opts.SCUNET_tile
|
||||||
|
tile_overlap = opts.SCUNET_tile_overlap
|
||||||
|
if tile == 0:
|
||||||
|
return model(img)
|
||||||
|
|
||||||
|
device = devices.get_device_for('scunet')
|
||||||
|
assert tile % 8 == 0, "tile size should be a multiple of window_size"
|
||||||
|
sf = 1
|
||||||
|
|
||||||
|
stride = tile - tile_overlap
|
||||||
|
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||||
|
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||||
|
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
|
||||||
|
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
|
||||||
|
|
||||||
|
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
|
||||||
|
for h_idx in h_idx_list:
|
||||||
|
|
||||||
|
for w_idx in w_idx_list:
|
||||||
|
|
||||||
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
|
|
||||||
|
out_patch = model(in_patch)
|
||||||
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
|
||||||
|
E[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch)
|
||||||
|
W[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch_mask)
|
||||||
|
pbar.update(1)
|
||||||
|
output = E.div_(W)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model = self.load_model(selected_file)
|
model = self.load_model(selected_file)
|
||||||
if model is None:
|
if model is None:
|
||||||
|
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
img = np.array(img)
|
tile = opts.SCUNET_tile
|
||||||
img = img[:, :, ::-1]
|
h, w = img.height, img.width
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
np_img = np.array(img)
|
||||||
img = torch.from_numpy(img).float()
|
np_img = np_img[:, :, ::-1] # RGB to BGR
|
||||||
img = img.unsqueeze(0).to(device)
|
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
|
||||||
|
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
|
||||||
|
|
||||||
with torch.no_grad():
|
if tile > h or tile > w:
|
||||||
output = model(img)
|
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
_img[:, :, :h, :w] = torch_img # pad image
|
||||||
output = 255. * np.moveaxis(output, 0, 2)
|
torch_img = _img
|
||||||
output = output.astype(np.uint8)
|
|
||||||
output = output[:, :, ::-1]
|
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
|
||||||
|
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
|
||||||
|
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
|
||||||
|
del torch_img, torch_output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return PIL.Image.fromarray(output, 'RGB')
|
|
||||||
|
output = np_output.transpose((1, 2, 0)) # CHW to HWC
|
||||||
|
output = output[:, :, ::-1] # BGR to RGB
|
||||||
|
return PIL.Image.fromarray((output * 255).astype(np.uint8))
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
@ -84,4 +138,3 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -1,103 +1,42 @@
|
|||||||
// Stable Diffusion WebUI - Bracket checker
|
// Stable Diffusion WebUI - Bracket checker
|
||||||
// Version 1.0
|
// By Hingashi no Florin/Bwin4L & @akx
|
||||||
// By Hingashi no Florin/Bwin4L
|
|
||||||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||||
|
|
||||||
function checkBrackets(evt, textArea, counterElt) {
|
function checkBrackets(textArea, counterElt) {
|
||||||
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
|
var counts = {};
|
||||||
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
|
(textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
|
||||||
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
|
counts[bracket] = (counts[bracket] || 0) + 1;
|
||||||
|
});
|
||||||
|
var errors = [];
|
||||||
|
|
||||||
openBracketRegExp = /\(/g;
|
function checkPair(open, close, kind) {
|
||||||
closeBracketRegExp = /\)/g;
|
if (counts[open] !== counts[close]) {
|
||||||
|
errors.push(
|
||||||
openSquareBracketRegExp = /\[/g;
|
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
|
||||||
closeSquareBracketRegExp = /\]/g;
|
);
|
||||||
|
|
||||||
openCurlyBracketRegExp = /\{/g;
|
|
||||||
closeCurlyBracketRegExp = /\}/g;
|
|
||||||
|
|
||||||
totalOpenBracketMatches = 0;
|
|
||||||
totalCloseBracketMatches = 0;
|
|
||||||
totalOpenSquareBracketMatches = 0;
|
|
||||||
totalCloseSquareBracketMatches = 0;
|
|
||||||
totalOpenCurlyBracketMatches = 0;
|
|
||||||
totalCloseCurlyBracketMatches = 0;
|
|
||||||
|
|
||||||
openBracketMatches = textArea.value.match(openBracketRegExp);
|
|
||||||
if(openBracketMatches) {
|
|
||||||
totalOpenBracketMatches = openBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
closeBracketMatches = textArea.value.match(closeBracketRegExp);
|
|
||||||
if(closeBracketMatches) {
|
|
||||||
totalCloseBracketMatches = closeBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
|
|
||||||
if(openSquareBracketMatches) {
|
|
||||||
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
|
|
||||||
if(closeSquareBracketMatches) {
|
|
||||||
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
|
|
||||||
if(openCurlyBracketMatches) {
|
|
||||||
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
|
|
||||||
if(closeCurlyBracketMatches) {
|
|
||||||
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(totalOpenBracketMatches != totalCloseBracketMatches) {
|
|
||||||
if(!counterElt.title.includes(errorStringParen)) {
|
|
||||||
counterElt.title += errorStringParen;
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
counterElt.title = counterElt.title.replace(errorStringParen, '');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
|
checkPair('(', ')', 'round brackets');
|
||||||
if(!counterElt.title.includes(errorStringSquare)) {
|
checkPair('[', ']', 'square brackets');
|
||||||
counterElt.title += errorStringSquare;
|
checkPair('{', '}', 'curly brackets');
|
||||||
}
|
counterElt.title = errors.join('\n');
|
||||||
} else {
|
counterElt.classList.toggle('error', errors.length !== 0);
|
||||||
counterElt.title = counterElt.title.replace(errorStringSquare, '');
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
|
function setupBracketChecking(id_prompt, id_counter) {
|
||||||
if(!counterElt.title.includes(errorStringCurly)) {
|
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
||||||
counterElt.title += errorStringCurly;
|
var counter = gradioApp().getElementById(id_counter)
|
||||||
}
|
|
||||||
} else {
|
|
||||||
counterElt.title = counterElt.title.replace(errorStringCurly, '');
|
|
||||||
}
|
|
||||||
|
|
||||||
if(counterElt.title != '') {
|
if (textarea && counter) {
|
||||||
counterElt.classList.add('error');
|
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
|
||||||
} else {
|
|
||||||
counterElt.classList.remove('error');
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function setupBracketChecking(id_prompt, id_counter){
|
onUiLoaded(function () {
|
||||||
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
|
||||||
var counter = gradioApp().getElementById(id_counter)
|
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
|
||||||
|
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
|
||||||
textarea.addEventListener("input", function(evt){
|
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
|
||||||
checkBrackets(evt, textarea, counter)
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
onUiLoaded(function(){
|
|
||||||
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
|
|
||||||
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
|
|
||||||
setupBracketChecking('img2img_prompt', 'img2img_token_counter')
|
|
||||||
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
|
|
||||||
})
|
|
||||||
|
@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
|
|||||||
|
|
||||||
var viewportOffset = targetElement.getBoundingClientRect();
|
var viewportOffset = targetElement.getBoundingClientRect();
|
||||||
|
|
||||||
viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
||||||
|
|
||||||
scaledx = targetElement.naturalWidth*viewportscale
|
var scaledx = targetElement.naturalWidth*viewportscale
|
||||||
scaledy = targetElement.naturalHeight*viewportscale
|
var scaledy = targetElement.naturalHeight*viewportscale
|
||||||
|
|
||||||
cleintRectTop = (viewportOffset.top+window.scrollY)
|
var cleintRectTop = (viewportOffset.top+window.scrollY)
|
||||||
cleintRectLeft = (viewportOffset.left+window.scrollX)
|
var cleintRectLeft = (viewportOffset.left+window.scrollX)
|
||||||
cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
||||||
cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
||||||
|
|
||||||
viewRectTop = cleintRectCentreY-(scaledy/2)
|
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
|
||||||
viewRectLeft = cleintRectCentreX-(scaledx/2)
|
var arscaledx = currentWidth*arscale
|
||||||
arRectWidth = scaledx
|
var arscaledy = currentHeight*arscale
|
||||||
arRectHeight = scaledy
|
|
||||||
|
|
||||||
arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
|
var arRectTop = cleintRectCentreY-(arscaledy/2)
|
||||||
arscaledx = currentWidth*arscale
|
var arRectLeft = cleintRectCentreX-(arscaledx/2)
|
||||||
arscaledy = currentHeight*arscale
|
var arRectWidth = arscaledx
|
||||||
|
var arRectHeight = arscaledy
|
||||||
arRectTop = cleintRectCentreY-(arscaledy/2)
|
|
||||||
arRectLeft = cleintRectCentreX-(arscaledx/2)
|
|
||||||
arRectWidth = arscaledx
|
|
||||||
arRectHeight = arscaledy
|
|
||||||
|
|
||||||
arPreviewRect.style.top = arRectTop+'px';
|
arPreviewRect.style.top = arRectTop+'px';
|
||||||
arPreviewRect.style.left = arRectLeft+'px';
|
arPreviewRect.style.left = arRectLeft+'px';
|
||||||
|
@ -4,7 +4,7 @@ contextMenuInit = function(){
|
|||||||
let menuSpecs = new Map();
|
let menuSpecs = new Map();
|
||||||
|
|
||||||
const uid = function(){
|
const uid = function(){
|
||||||
return Date.now().toString(36) + Math.random().toString(36).substr(2);
|
return Date.now().toString(36) + Math.random().toString(36).substring(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
function showContextMenu(event,element,menuEntries){
|
function showContextMenu(event,element,menuEntries){
|
||||||
@ -16,8 +16,7 @@ contextMenuInit = function(){
|
|||||||
oldMenu.remove()
|
oldMenu.remove()
|
||||||
}
|
}
|
||||||
|
|
||||||
let tabButton = uiCurrentTab
|
let baseStyle = window.getComputedStyle(uiCurrentTab)
|
||||||
let baseStyle = window.getComputedStyle(tabButton)
|
|
||||||
|
|
||||||
const contextMenu = document.createElement('nav')
|
const contextMenu = document.createElement('nav')
|
||||||
contextMenu.id = "context-menu"
|
contextMenu.id = "context-menu"
|
||||||
@ -36,7 +35,7 @@ contextMenuInit = function(){
|
|||||||
menuEntries.forEach(function(entry){
|
menuEntries.forEach(function(entry){
|
||||||
let contextMenuEntry = document.createElement('a')
|
let contextMenuEntry = document.createElement('a')
|
||||||
contextMenuEntry.innerHTML = entry['name']
|
contextMenuEntry.innerHTML = entry['name']
|
||||||
contextMenuEntry.addEventListener("click", function(e) {
|
contextMenuEntry.addEventListener("click", function() {
|
||||||
entry['func']();
|
entry['func']();
|
||||||
})
|
})
|
||||||
contextMenuList.append(contextMenuEntry);
|
contextMenuList.append(contextMenuEntry);
|
||||||
@ -63,7 +62,7 @@ contextMenuInit = function(){
|
|||||||
|
|
||||||
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
||||||
|
|
||||||
currentItems = menuSpecs.get(targetElementSelector)
|
var currentItems = menuSpecs.get(targetElementSelector)
|
||||||
|
|
||||||
if(!currentItems){
|
if(!currentItems){
|
||||||
currentItems = []
|
currentItems = []
|
||||||
@ -79,7 +78,7 @@ contextMenuInit = function(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function removeContextMenuOption(uid){
|
function removeContextMenuOption(uid){
|
||||||
menuSpecs.forEach(function(v,k) {
|
menuSpecs.forEach(function(v) {
|
||||||
let index = -1
|
let index = -1
|
||||||
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
||||||
if(index>=0){
|
if(index>=0){
|
||||||
@ -112,7 +111,6 @@ contextMenuInit = function(){
|
|||||||
if(e.composedPath()[0].matches(k)){
|
if(e.composedPath()[0].matches(k)){
|
||||||
showContextMenu(e,e.composedPath()[0],v)
|
showContextMenu(e,e.composedPath()[0],v)
|
||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@ -161,14 +159,6 @@ addContextMenuEventListener = initResponse[2];
|
|||||||
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
||||||
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
||||||
|
|
||||||
appendContextMenuOption('#roll','Roll three',
|
|
||||||
function(){
|
|
||||||
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
|
|
||||||
setTimeout(function(){rollbutton.click()},100)
|
|
||||||
setTimeout(function(){rollbutton.click()},200)
|
|
||||||
setTimeout(function(){rollbutton.click()},300)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
})();
|
})();
|
||||||
//End example Context Menu Items
|
//End example Context Menu Items
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ function keyupEditAttention(event){
|
|||||||
// Find opening parenthesis around current cursor
|
// Find opening parenthesis around current cursor
|
||||||
const before = text.substring(0, selectionStart);
|
const before = text.substring(0, selectionStart);
|
||||||
let beforeParen = before.lastIndexOf(OPEN);
|
let beforeParen = before.lastIndexOf(OPEN);
|
||||||
if (beforeParen == -1) return false;
|
if (beforeParen == -1) return false;
|
||||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
let beforeParenClose = before.lastIndexOf(CLOSE);
|
||||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
||||||
@ -27,7 +27,7 @@ function keyupEditAttention(event){
|
|||||||
// Find closing parenthesis around current cursor
|
// Find closing parenthesis around current cursor
|
||||||
const after = text.substring(selectionStart);
|
const after = text.substring(selectionStart);
|
||||||
let afterParen = after.indexOf(CLOSE);
|
let afterParen = after.indexOf(CLOSE);
|
||||||
if (afterParen == -1) return false;
|
if (afterParen == -1) return false;
|
||||||
let afterParenOpen = after.indexOf(OPEN);
|
let afterParenOpen = after.indexOf(OPEN);
|
||||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
||||||
@ -43,16 +43,34 @@ function keyupEditAttention(event){
|
|||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function selectCurrentWord(){
|
||||||
|
if (selectionStart !== selectionEnd) return false;
|
||||||
|
const delimiters = opts.keyedit_delimiters + " \r\n\t";
|
||||||
|
|
||||||
|
// seek backward until to find beggining
|
||||||
|
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
||||||
|
selectionStart--;
|
||||||
|
}
|
||||||
|
|
||||||
|
// seek forward to find end
|
||||||
|
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
|
||||||
|
selectionEnd++;
|
||||||
|
}
|
||||||
|
|
||||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
if(! selectCurrentParenthesisBlock('<', '>')){
|
return true;
|
||||||
selectCurrentParenthesisBlock('(', ')')
|
}
|
||||||
|
|
||||||
|
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
||||||
|
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
||||||
|
selectCurrentWord();
|
||||||
}
|
}
|
||||||
|
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
closeCharacter = ')'
|
var closeCharacter = ')'
|
||||||
delta = opts.keyedit_precision_attention
|
var delta = opts.keyedit_precision_attention
|
||||||
|
|
||||||
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
||||||
closeCharacter = '>'
|
closeCharacter = '>'
|
||||||
@ -73,15 +91,21 @@ function keyupEditAttention(event){
|
|||||||
selectionEnd += 1;
|
selectionEnd += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||||
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||||
if (isNaN(weight)) return;
|
if (isNaN(weight)) return;
|
||||||
|
|
||||||
weight += isPlus ? delta : -delta;
|
weight += isPlus ? delta : -delta;
|
||||||
weight = parseFloat(weight.toPrecision(12));
|
weight = parseFloat(weight.toPrecision(12));
|
||||||
if(String(weight).length == 1) weight += ".0"
|
if(String(weight).length == 1) weight += ".0"
|
||||||
|
|
||||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
if (closeCharacter == ')' && weight == 1) {
|
||||||
|
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
|
||||||
|
selectionStart--;
|
||||||
|
selectionEnd--;
|
||||||
|
} else {
|
||||||
|
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
||||||
|
}
|
||||||
|
|
||||||
target.focus();
|
target.focus();
|
||||||
target.value = text;
|
target.value = text;
|
||||||
@ -93,4 +117,4 @@ function keyupEditAttention(event){
|
|||||||
|
|
||||||
addEventListener('keydown', (event) => {
|
addEventListener('keydown', (event) => {
|
||||||
keyupEditAttention(event);
|
keyupEditAttention(event);
|
||||||
});
|
});
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
|
|
||||||
function extensions_apply(_, _, disable_all){
|
function extensions_apply(_disabled_list, _update_list, disable_all){
|
||||||
var disable = []
|
var disable = []
|
||||||
var update = []
|
var update = []
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if(x.name.startsWith("enable_") && ! x.checked)
|
||||||
disable.push(x.name.substr(7))
|
disable.push(x.name.substring(7))
|
||||||
|
|
||||||
if(x.name.startsWith("update_") && x.checked)
|
if(x.name.startsWith("update_") && x.checked)
|
||||||
update.push(x.name.substr(7))
|
update.push(x.name.substring(7))
|
||||||
})
|
})
|
||||||
|
|
||||||
restart_reload()
|
restart_reload()
|
||||||
@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
|
|||||||
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
||||||
}
|
}
|
||||||
|
|
||||||
function extensions_check(_, _){
|
function extensions_check(){
|
||||||
var disable = []
|
var disable = []
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if(x.name.startsWith("enable_") && ! x.checked)
|
||||||
disable.push(x.name.substr(7))
|
disable.push(x.name.substring(7))
|
||||||
})
|
})
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||||
@ -41,9 +41,31 @@ function install_extension_from_index(button, url){
|
|||||||
button.disabled = "disabled"
|
button.disabled = "disabled"
|
||||||
button.value = "Installing..."
|
button.value = "Installing..."
|
||||||
|
|
||||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
var textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||||
textarea.value = url
|
textarea.value = url
|
||||||
updateInput(textarea)
|
updateInput(textarea)
|
||||||
|
|
||||||
gradioApp().querySelector('#install_extension_button').click()
|
gradioApp().querySelector('#install_extension_button').click()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
|
||||||
|
if (config_state_name == "Current") {
|
||||||
|
return [false, config_state_name, config_restore_type];
|
||||||
|
}
|
||||||
|
let restored = "";
|
||||||
|
if (config_restore_type == "extensions") {
|
||||||
|
restored = "all saved extension versions";
|
||||||
|
} else if (config_restore_type == "webui") {
|
||||||
|
restored = "the webui version";
|
||||||
|
} else {
|
||||||
|
restored = "the webui version and all saved extension versions";
|
||||||
|
}
|
||||||
|
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
|
||||||
|
if (confirmed) {
|
||||||
|
restart_reload();
|
||||||
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||||
|
x.innerHTML = "Loading..."
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return [confirmed, config_state_name, config_restore_type];
|
||||||
|
}
|
||||||
|
@ -10,11 +10,11 @@ function setupExtraNetworksForTab(tabname){
|
|||||||
tabs.appendChild(search)
|
tabs.appendChild(search)
|
||||||
tabs.appendChild(refresh)
|
tabs.appendChild(refresh)
|
||||||
|
|
||||||
search.addEventListener("input", function(evt){
|
search.addEventListener("input", function(){
|
||||||
searchTerm = search.value.toLowerCase()
|
var searchTerm = search.value.toLowerCase()
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
||||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@ -55,7 +55,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
|||||||
|
|
||||||
var partToSearch = m[1]
|
var partToSearch = m[1]
|
||||||
var replaced = false
|
var replaced = false
|
||||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
|
||||||
m = found.match(re_extranet);
|
m = found.match(re_extranet);
|
||||||
if(m[1] == partToSearch){
|
if(m[1] == partToSearch){
|
||||||
replaced = true;
|
replaced = true;
|
||||||
@ -96,9 +96,9 @@ function saveCardPreview(event, tabname, filename){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event){
|
function extraNetworksSearchButton(tabs_id, event){
|
||||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
||||||
button = event.target
|
var button = event.target
|
||||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
||||||
|
|
||||||
searchTextarea.value = text
|
searchTextarea.value = text
|
||||||
updateInput(searchTextarea)
|
updateInput(searchTextarea)
|
||||||
@ -133,7 +133,7 @@ function popup(contents){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksShowMetadata(text){
|
function extraNetworksShowMetadata(text){
|
||||||
elem = document.createElement('pre')
|
var elem = document.createElement('pre')
|
||||||
elem.classList.add('popup-metadata');
|
elem.classList.add('popup-metadata');
|
||||||
elem.textContent = text;
|
elem.textContent = text;
|
||||||
|
|
||||||
@ -165,7 +165,7 @@ function requestGet(url, data, handler, errorHandler){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
||||||
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
||||||
|
|
||||||
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
||||||
if(data && data.metadata){
|
if(data && data.metadata){
|
||||||
|
@ -16,14 +16,14 @@ onUiUpdate(function(){
|
|||||||
|
|
||||||
let modalObserver = new MutationObserver(function(mutations) {
|
let modalObserver = new MutationObserver(function(mutations) {
|
||||||
mutations.forEach(function(mutationRecord) {
|
mutations.forEach(function(mutationRecord) {
|
||||||
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
|
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
|
||||||
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
|
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
|
||||||
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
|
gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
function attachGalleryListeners(tab_name) {
|
function attachGalleryListeners(tab_name) {
|
||||||
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
||||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
||||||
gallery?.addEventListener('keydown', (e) => {
|
gallery?.addEventListener('keydown', (e) => {
|
||||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
||||||
|
@ -22,6 +22,7 @@ titles = {
|
|||||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||||
"\u{1f4d2}": "Paste available values into the field",
|
"\u{1f4d2}": "Paste available values into the field",
|
||||||
"\u{1f3b4}": "Show/hide extra networks",
|
"\u{1f3b4}": "Show/hide extra networks",
|
||||||
|
"\u{1f300}": "Restore progress",
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
@ -65,8 +66,8 @@ titles = {
|
|||||||
|
|
||||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||||
|
|
||||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
||||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
|
||||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||||
|
|
||||||
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
||||||
@ -85,7 +86,6 @@ titles = {
|
|||||||
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
||||||
|
|
||||||
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
||||||
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
|
|
||||||
|
|
||||||
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
|
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
|
||||||
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
|
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
|
||||||
@ -111,15 +111,18 @@ titles = {
|
|||||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
||||||
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
||||||
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
||||||
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
|
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
|
||||||
|
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
||||||
tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
if (span.title) return; // already has a title
|
||||||
|
|
||||||
if(!tooltip){
|
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
||||||
|
|
||||||
|
if(!tooltip){
|
||||||
tooltip = localization[titles[span.value]] || titles[span.value];
|
tooltip = localization[titles[span.value]] || titles[span.value];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
|
|
||||||
function setInactive(elem, inactive){
|
|
||||||
if(inactive){
|
|
||||||
elem.classList.add('inactive')
|
|
||||||
} else{
|
|
||||||
elem.classList.remove('inactive')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
||||||
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
function setInactive(elem, inactive){
|
||||||
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
elem.classList.toggle('inactive', !!inactive)
|
||||||
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
}
|
||||||
|
|
||||||
|
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
||||||
|
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
||||||
|
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
||||||
|
|
||||||
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
||||||
* @see https://github.com/gradio-app/gradio/issues/1721
|
* @see https://github.com/gradio-app/gradio/issues/1721
|
||||||
*/
|
*/
|
||||||
window.addEventListener( 'resize', () => imageMaskResize());
|
|
||||||
function imageMaskResize() {
|
function imageMaskResize() {
|
||||||
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
||||||
if ( ! canvases.length ) {
|
if ( ! canvases.length ) {
|
||||||
canvases_fixed = false;
|
canvases_fixed = false; // TODO: this is unused..?
|
||||||
window.removeEventListener( 'resize', imageMaskResize );
|
window.removeEventListener( 'resize', imageMaskResize );
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -15,7 +14,7 @@ function imageMaskResize() {
|
|||||||
const previewImage = wrapper.previousElementSibling;
|
const previewImage = wrapper.previousElementSibling;
|
||||||
|
|
||||||
if ( ! previewImage.complete ) {
|
if ( ! previewImage.complete ) {
|
||||||
previewImage.addEventListener( 'load', () => imageMaskResize());
|
previewImage.addEventListener( 'load', imageMaskResize);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +23,6 @@ function imageMaskResize() {
|
|||||||
const nw = previewImage.naturalWidth;
|
const nw = previewImage.naturalWidth;
|
||||||
const nh = previewImage.naturalHeight;
|
const nh = previewImage.naturalHeight;
|
||||||
const portrait = nh > nw;
|
const portrait = nh > nw;
|
||||||
const factor = portrait;
|
|
||||||
|
|
||||||
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
||||||
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
||||||
@ -40,6 +38,7 @@ function imageMaskResize() {
|
|||||||
c.style.maxHeight = '100%';
|
c.style.maxHeight = '100%';
|
||||||
c.style.objectFit = 'contain';
|
c.style.objectFit = 'contain';
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(() => imageMaskResize());
|
onUiUpdate(imageMaskResize);
|
||||||
|
window.addEventListener( 'resize', imageMaskResize);
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
window.onload = (function(){
|
window.onload = (function(){
|
||||||
window.addEventListener('drop', e => {
|
window.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const idx = selected_gallery_index();
|
|
||||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||||
|
|
||||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||||
|
@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if (result != -1) {
|
if (result != -1) {
|
||||||
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
||||||
nextButton.click()
|
nextButton.click()
|
||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomSet(modalImage, enable) {
|
function modalZoomSet(modalImage, enable) {
|
||||||
if (enable) {
|
if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
|
||||||
modalImage.classList.add('modalImageFullscreen');
|
|
||||||
} else {
|
|
||||||
modalImage.classList.remove('modalImageFullscreen');
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomToggle(event) {
|
function modalZoomToggle(event) {
|
||||||
modalImage = gradioApp().getElementById("modalImage");
|
var modalImage = gradioApp().getElementById("modalImage");
|
||||||
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
@ -179,7 +175,7 @@ function galleryImageHandler(e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onUiUpdate(function() {
|
||||||
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
||||||
if (fullImg_preview != null) {
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(setupImageForLightbox);
|
fullImg_preview.forEach(setupImageForLightbox);
|
||||||
}
|
}
|
||||||
@ -251,8 +247,11 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
|
|
||||||
modal.appendChild(modalNext)
|
modal.appendChild(modalNext)
|
||||||
|
|
||||||
gradioApp().appendChild(modal)
|
try {
|
||||||
|
gradioApp().appendChild(modal);
|
||||||
|
} catch (e) {
|
||||||
|
gradioApp().body.appendChild(modal);
|
||||||
|
}
|
||||||
|
|
||||||
document.body.appendChild(modal);
|
document.body.appendChild(modal);
|
||||||
|
|
||||||
|
57
javascript/imageviewerGamepad.js
Normal file
57
javascript/imageviewerGamepad.js
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
window.addEventListener('gamepadconnected', (e) => {
|
||||||
|
const index = e.gamepad.index;
|
||||||
|
let isWaiting = false;
|
||||||
|
setInterval(async () => {
|
||||||
|
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
||||||
|
const gamepad = navigator.getGamepads()[index];
|
||||||
|
const xValue = gamepad.axes[0];
|
||||||
|
if (xValue <= -0.3) {
|
||||||
|
modalPrevImage(e);
|
||||||
|
isWaiting = true;
|
||||||
|
} else if (xValue >= 0.3) {
|
||||||
|
modalNextImage(e);
|
||||||
|
isWaiting = true;
|
||||||
|
}
|
||||||
|
if (isWaiting) {
|
||||||
|
await sleepUntil(() => {
|
||||||
|
const xValue = navigator.getGamepads()[index].axes[0]
|
||||||
|
if (xValue < 0.3 && xValue > -0.3) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}, opts.js_modal_lightbox_gamepad_repeat);
|
||||||
|
isWaiting = false;
|
||||||
|
}
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
|
||||||
|
/*
|
||||||
|
Primarily for vr controller type pointer devices.
|
||||||
|
I use the wheel event because there's currently no way to do it properly with web xr.
|
||||||
|
*/
|
||||||
|
let isScrolling = false;
|
||||||
|
window.addEventListener('wheel', (e) => {
|
||||||
|
if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
|
||||||
|
isScrolling = true;
|
||||||
|
|
||||||
|
if (e.deltaX <= -0.6) {
|
||||||
|
modalPrevImage(e);
|
||||||
|
} else if (e.deltaX >= 0.6) {
|
||||||
|
modalNextImage(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
isScrolling = false;
|
||||||
|
}, opts.js_modal_lightbox_gamepad_repeat);
|
||||||
|
});
|
||||||
|
|
||||||
|
function sleepUntil(f, timeout) {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const timeStart = new Date();
|
||||||
|
const wait = setInterval(function() {
|
||||||
|
if (f() || new Date() - timeStart > timeout) {
|
||||||
|
clearInterval(wait);
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
}, 20);
|
||||||
|
});
|
||||||
|
}
|
@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
|||||||
original_lines = {}
|
original_lines = {}
|
||||||
translated_lines = {}
|
translated_lines = {}
|
||||||
|
|
||||||
|
function hasLocalization() {
|
||||||
|
return window.localization && Object.keys(window.localization).length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
function textNodesUnder(el){
|
function textNodesUnder(el){
|
||||||
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
||||||
while(n=walk.nextNode()) a.push(n);
|
while(n=walk.nextNode()) a.push(n);
|
||||||
@ -35,11 +39,11 @@ function canBeTranslated(node, text){
|
|||||||
if(! text) return false;
|
if(! text) return false;
|
||||||
if(! node.parentElement) return false;
|
if(! node.parentElement) return false;
|
||||||
|
|
||||||
parentType = node.parentElement.nodeName
|
var parentType = node.parentElement.nodeName
|
||||||
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
||||||
|
|
||||||
if (parentType=='OPTION' || parentType=='SPAN'){
|
if (parentType=='OPTION' || parentType=='SPAN'){
|
||||||
pnode = node
|
var pnode = node
|
||||||
for(var level=0; level<4; level++){
|
for(var level=0; level<4; level++){
|
||||||
pnode = pnode.parentElement
|
pnode = pnode.parentElement
|
||||||
if(! pnode) break;
|
if(! pnode) break;
|
||||||
@ -69,7 +73,7 @@ function getTranslation(text){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function processTextNode(node){
|
function processTextNode(node){
|
||||||
text = node.textContent.trim()
|
var text = node.textContent.trim()
|
||||||
|
|
||||||
if(! canBeTranslated(node, text)) return
|
if(! canBeTranslated(node, text)) return
|
||||||
|
|
||||||
@ -105,7 +109,7 @@ function processNode(node){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function dumpTranslations(){
|
function dumpTranslations(){
|
||||||
dumped = {}
|
var dumped = {}
|
||||||
if (localization.rtl) {
|
if (localization.rtl) {
|
||||||
dumped.rtl = true
|
dumped.rtl = true
|
||||||
}
|
}
|
||||||
@ -119,39 +123,8 @@ function dumpTranslations(){
|
|||||||
return dumped
|
return dumped
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function(m){
|
|
||||||
m.forEach(function(mutation){
|
|
||||||
mutation.addedNodes.forEach(function(node){
|
|
||||||
processNode(node)
|
|
||||||
})
|
|
||||||
});
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
|
||||||
processNode(gradioApp())
|
|
||||||
|
|
||||||
if (localization.rtl) { // if the language is from right to left,
|
|
||||||
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
|
||||||
mutations.forEach(mutation => {
|
|
||||||
mutation.addedNodes.forEach(node => {
|
|
||||||
if (node.tagName === 'STYLE') {
|
|
||||||
observer.disconnect();
|
|
||||||
|
|
||||||
for (const x of node.sheet.rules) { // find all rtl media rules
|
|
||||||
if (Array.from(x.media || []).includes('rtl')) {
|
|
||||||
x.media.appendMedium('all'); // enable them
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
});
|
|
||||||
})).observe(gradioApp(), { childList: true });
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
function download_localization() {
|
function download_localization() {
|
||||||
text = JSON.stringify(dumpTranslations(), null, 4)
|
var text = JSON.stringify(dumpTranslations(), null, 4)
|
||||||
|
|
||||||
var element = document.createElement('a');
|
var element = document.createElement('a');
|
||||||
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
||||||
@ -163,3 +136,36 @@ function download_localization() {
|
|||||||
|
|
||||||
document.body.removeChild(element);
|
document.body.removeChild(element);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(hasLocalization()) {
|
||||||
|
onUiUpdate(function (m) {
|
||||||
|
m.forEach(function (mutation) {
|
||||||
|
mutation.addedNodes.forEach(function (node) {
|
||||||
|
processNode(node)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
processNode(gradioApp())
|
||||||
|
|
||||||
|
if (localization.rtl) { // if the language is from right to left,
|
||||||
|
(new MutationObserver((mutations, observer) => { // wait for the style to load
|
||||||
|
mutations.forEach(mutation => {
|
||||||
|
mutation.addedNodes.forEach(node => {
|
||||||
|
if (node.tagName === 'STYLE') {
|
||||||
|
observer.disconnect();
|
||||||
|
|
||||||
|
for (const x of node.sheet.rules) { // find all rtl media rules
|
||||||
|
if (Array.from(x.media || []).includes('rtl')) {
|
||||||
|
x.media.appendMedium('all'); // enable them
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
})).observe(gradioApp(), { childList: true });
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
@ -2,15 +2,15 @@
|
|||||||
|
|
||||||
let lastHeadImg = null;
|
let lastHeadImg = null;
|
||||||
|
|
||||||
notificationButton = null
|
let notificationButton = null;
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(notificationButton == null){
|
if(notificationButton == null){
|
||||||
notificationButton = gradioApp().getElementById('request_notifications')
|
notificationButton = gradioApp().getElementById('request_notifications')
|
||||||
|
|
||||||
if(notificationButton != null){
|
if(notificationButton != null){
|
||||||
notificationButton.addEventListener('click', function (evt) {
|
notificationButton.addEventListener('click', () => {
|
||||||
Notification.requestPermission();
|
void Notification.requestPermission();
|
||||||
},true);
|
},true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
|
|
||||||
function rememberGallerySelection(id_gallery){
|
function rememberGallerySelection(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getGallerySelectedIndex(id_gallery){
|
function getGallerySelectedIndex(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function request(url, data, handler, errorHandler){
|
function request(url, data, handler, errorHandler){
|
||||||
var xhr = new XMLHttpRequest();
|
var xhr = new XMLHttpRequest();
|
||||||
var url = url;
|
|
||||||
xhr.open("POST", url, true);
|
xhr.open("POST", url, true);
|
||||||
xhr.setRequestHeader("Content-Type", "application/json");
|
xhr.setRequestHeader("Content-Type", "application/json");
|
||||||
xhr.onreadystatechange = function () {
|
xhr.onreadystatechange = function () {
|
||||||
@ -66,7 +65,7 @@ function randomId(){
|
|||||||
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
||||||
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
||||||
// calls onProgress every time there is a progress update
|
// calls onProgress every time there is a progress update
|
||||||
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
|
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
|
||||||
var dateStart = new Date()
|
var dateStart = new Date()
|
||||||
var wasEverActive = false
|
var wasEverActive = false
|
||||||
var parentProgressbar = progressbarContainer.parentNode
|
var parentProgressbar = progressbarContainer.parentNode
|
||||||
@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
divProgress.style.width = rect.width + "px";
|
divProgress.style.width = rect.width + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
progressText = ""
|
let progressText = ""
|
||||||
|
|
||||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
||||||
divInner.style.background = res.progress ? "" : "transparent"
|
divInner.style.background = res.progress ? "" : "transparent"
|
||||||
@ -138,7 +137,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if(elapsedFromStart > 5 && !res.queued && !res.active){
|
if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
|
||||||
removeProgressBar()
|
removeProgressBar()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
||||||
|
|
||||||
function set_theme(theme){
|
function set_theme(theme){
|
||||||
gradioURL = window.location.href
|
var gradioURL = window.location.href
|
||||||
if (!gradioURL.includes('?__theme=')) {
|
if (!gradioURL.includes('?__theme=')) {
|
||||||
window.location.replace(gradioURL + '?__theme=' + theme);
|
window.location.replace(gradioURL + '?__theme=' + theme);
|
||||||
}
|
}
|
||||||
@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
|
|||||||
return [gallery[0]];
|
return [gallery[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
index = selected_gallery_index()
|
var index = selected_gallery_index()
|
||||||
|
|
||||||
if (index < 0 || index >= gallery.length){
|
if (index < 0 || index >= gallery.length){
|
||||||
// Use the first image in the gallery as the default
|
// Use the first image in the gallery as the default
|
||||||
@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function args_to_array(args){
|
function args_to_array(args){
|
||||||
res = []
|
var res = []
|
||||||
for(var i=0;i<args.length;i++){
|
for(var i=0;i<args.length;i++){
|
||||||
res.push(args[i])
|
res.push(args[i])
|
||||||
}
|
}
|
||||||
@ -138,7 +138,7 @@ function get_img2img_tab_index() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function create_submit_args(args){
|
function create_submit_args(args){
|
||||||
res = []
|
var res = []
|
||||||
for(var i=0;i<args.length;i++){
|
for(var i=0;i<args.length;i++){
|
||||||
res.push(args[i])
|
res.push(args[i])
|
||||||
}
|
}
|
||||||
@ -159,14 +159,24 @@ function showSubmitButtons(tabname, show){
|
|||||||
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function showRestoreProgressButton(tabname, show){
|
||||||
|
var button = gradioApp().getElementById(tabname + "_restore_progress")
|
||||||
|
if(! button) return
|
||||||
|
|
||||||
|
button.style.display = show ? "flex" : "none"
|
||||||
|
}
|
||||||
|
|
||||||
function submit(){
|
function submit(){
|
||||||
rememberGallerySelection('txt2img_gallery')
|
rememberGallerySelection('txt2img_gallery')
|
||||||
showSubmitButtons('txt2img', false)
|
showSubmitButtons('txt2img', false)
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
|
localStorage.setItem("txt2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
||||||
showSubmitButtons('txt2img', true)
|
showSubmitButtons('txt2img', true)
|
||||||
|
localStorage.removeItem("txt2img_task_id")
|
||||||
|
showRestoreProgressButton('txt2img', false)
|
||||||
})
|
})
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments)
|
||||||
@ -181,8 +191,12 @@ function submit_img2img(){
|
|||||||
showSubmitButtons('img2img', false)
|
showSubmitButtons('img2img', false)
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
|
localStorage.setItem("img2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||||
showSubmitButtons('img2img', true)
|
showSubmitButtons('img2img', true)
|
||||||
|
localStorage.removeItem("img2img_task_id")
|
||||||
|
showRestoreProgressButton('img2img', false)
|
||||||
})
|
})
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments)
|
||||||
@ -193,6 +207,42 @@ function submit_img2img(){
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function restoreProgressTxt2img(){
|
||||||
|
showRestoreProgressButton("txt2img", false)
|
||||||
|
var id = localStorage.getItem("txt2img_task_id")
|
||||||
|
|
||||||
|
id = localStorage.getItem("txt2img_task_id")
|
||||||
|
|
||||||
|
if(id) {
|
||||||
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
||||||
|
showSubmitButtons('txt2img', true)
|
||||||
|
}, null, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
function restoreProgressImg2img(){
|
||||||
|
showRestoreProgressButton("img2img", false)
|
||||||
|
|
||||||
|
var id = localStorage.getItem("img2img_task_id")
|
||||||
|
|
||||||
|
if(id) {
|
||||||
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||||
|
showSubmitButtons('img2img', true)
|
||||||
|
}, null, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
onUiLoaded(function () {
|
||||||
|
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"))
|
||||||
|
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"))
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
function modelmerger(){
|
function modelmerger(){
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||||
@ -204,7 +254,7 @@ function modelmerger(){
|
|||||||
|
|
||||||
|
|
||||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||||
name_ = prompt('Style name:')
|
var name_ = prompt('Style name:')
|
||||||
return [name_, prompt_text, negative_prompt_text]
|
return [name_, prompt_text, negative_prompt_text]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -239,11 +289,11 @@ function recalculate_prompts_img2img(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
opts = {}
|
var opts = {}
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(Object.keys(opts).length != 0) return;
|
if(Object.keys(opts).length != 0) return;
|
||||||
|
|
||||||
json_elem = gradioApp().getElementById('settings_json')
|
var json_elem = gradioApp().getElementById('settings_json')
|
||||||
if(json_elem == null) return;
|
if(json_elem == null) return;
|
||||||
|
|
||||||
var textarea = json_elem.querySelector('textarea')
|
var textarea = json_elem.querySelector('textarea')
|
||||||
@ -292,8 +342,8 @@ onUiUpdate(function(){
|
|||||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
||||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
||||||
|
|
||||||
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
||||||
settings_tabs = gradioApp().querySelector('#settings div')
|
var settings_tabs = gradioApp().querySelector('#settings div')
|
||||||
if(show_all_pages && settings_tabs){
|
if(show_all_pages && settings_tabs){
|
||||||
settings_tabs.appendChild(show_all_pages)
|
settings_tabs.appendChild(show_all_pages)
|
||||||
show_all_pages.onclick = function(){
|
show_all_pages.onclick = function(){
|
||||||
@ -305,9 +355,9 @@ onUiUpdate(function(){
|
|||||||
})
|
})
|
||||||
|
|
||||||
onOptionsChanged(function(){
|
onOptionsChanged(function(){
|
||||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
var elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||||
shorthash = sd_checkpoint_hash.substr(0,10)
|
var shorthash = sd_checkpoint_hash.substring(0,10)
|
||||||
|
|
||||||
if(elem && elem.textContent != shorthash){
|
if(elem && elem.textContent != shorthash){
|
||||||
elem.textContent = shorthash
|
elem.textContent = shorthash
|
||||||
@ -361,3 +411,19 @@ function selectCheckpoint(name){
|
|||||||
desiredCheckpointName = name;
|
desiredCheckpointName = name;
|
||||||
gradioApp().getElementById('change_checkpoint').click()
|
gradioApp().getElementById('change_checkpoint').click()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function currentImg2imgSourceResolution(_, _, scaleBy){
|
||||||
|
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
|
||||||
|
return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateImg2imgResizeToTextAfterChangingImage(){
|
||||||
|
// At the time this is called from gradio, the image has no yet been replaced.
|
||||||
|
// There may be a better solution, but this is simple and straightforward so I'm going with it.
|
||||||
|
|
||||||
|
setTimeout(function() {
|
||||||
|
gradioApp().getElementById('img2img_update_resize_to').click()
|
||||||
|
}, 500);
|
||||||
|
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
19
launch.py
19
launch.py
@ -19,7 +19,6 @@ python = sys.executable
|
|||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
index_url = os.environ.get('INDEX_URL', "")
|
index_url = os.environ.get('INDEX_URL', "")
|
||||||
stored_commit_hash = None
|
stored_commit_hash = None
|
||||||
skip_install = False
|
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
|
|
||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||||
@ -49,7 +48,7 @@ or any other error regarding unsuccessful package (library) installation,
|
|||||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||||
and delete current Python and "venv" folder in WebUI's directory.
|
and delete current Python and "venv" folder in WebUI's directory.
|
||||||
|
|
||||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
|
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||||
|
|
||||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||||
|
|
||||||
@ -121,12 +120,12 @@ def run_python(code, desc=None, errdesc=None):
|
|||||||
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
||||||
|
|
||||||
|
|
||||||
def run_pip(args, desc=None):
|
def run_pip(command, desc=None, live=False):
|
||||||
if skip_install:
|
if args.skip_install:
|
||||||
return
|
return
|
||||||
|
|
||||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||||
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code):
|
||||||
@ -223,12 +222,10 @@ def run_extensions_installers(settings_file):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
global skip_install
|
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu118")
|
||||||
|
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
||||||
@ -271,7 +268,7 @@ def prepare_environment():
|
|||||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
if platform.python_version().startswith("3.10"):
|
if platform.python_version().startswith("3.10"):
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
||||||
else:
|
else:
|
||||||
print("Installation of xformers is not supported in this version of Python.")
|
print("Installation of xformers is not supported in this version of Python.")
|
||||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||||
@ -296,7 +293,7 @@ def prepare_environment():
|
|||||||
|
|
||||||
if not os.path.isfile(requirements_file):
|
if not os.path.isfile(requirements_file):
|
||||||
requirements_file = os.path.join(script_path, requirements_file)
|
requirements_file = os.path.join(script_path, requirements_file)
|
||||||
run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
|
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||||
|
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
|
@ -6,7 +6,6 @@ import uvicorn
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from gradio.processing_utils import decode_base64_to_file
|
|
||||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
@ -131,8 +130,8 @@ def api_middleware(app: FastAPI):
|
|||||||
"body": vars(e).get('body', ''),
|
"body": vars(e).get('body', ''),
|
||||||
"errors": str(e),
|
"errors": str(e),
|
||||||
}
|
}
|
||||||
print(f"API error: {request.method}: {request.url} {err}")
|
|
||||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||||
|
print(f"API error: {request.method}: {request.url} {err}")
|
||||||
if rich_available:
|
if rich_available:
|
||||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||||
else:
|
else:
|
||||||
@ -272,7 +271,9 @@ class Api:
|
|||||||
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
||||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||||
script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
|
# min between arg length in scriptrunner and arg length in the request
|
||||||
|
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
||||||
|
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||||
return script_args
|
return script_args
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
||||||
@ -395,16 +396,11 @@ class Api:
|
|||||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
def prepareFiles(file):
|
image_list = reqDict.pop('imageList', [])
|
||||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
||||||
file.orig_name = file.name
|
|
||||||
return file
|
|
||||||
|
|
||||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
|
||||||
reqDict.pop('imageList')
|
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
|
@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
|
progress.record_results(id_task, res)
|
||||||
finally:
|
finally:
|
||||||
progress.finish_task(id_task)
|
progress.finish_task(id_task)
|
||||||
|
|
||||||
|
@ -95,9 +95,11 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
|
|||||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||||
|
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
200
modules/config_states.py
Normal file
200
modules/config_states.py
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
"""
|
||||||
|
Supports saving and restoring webui and extensions from a known working set of commits
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import OrderedDict
|
||||||
|
import git
|
||||||
|
|
||||||
|
from modules import shared, extensions
|
||||||
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir
|
||||||
|
|
||||||
|
|
||||||
|
all_config_states = OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
def list_config_states():
|
||||||
|
global all_config_states
|
||||||
|
|
||||||
|
all_config_states.clear()
|
||||||
|
os.makedirs(config_states_dir, exist_ok=True)
|
||||||
|
|
||||||
|
config_states = []
|
||||||
|
for filename in os.listdir(config_states_dir):
|
||||||
|
if filename.endswith(".json"):
|
||||||
|
path = os.path.join(config_states_dir, filename)
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
j = json.load(f)
|
||||||
|
j["filepath"] = path
|
||||||
|
config_states.append(j)
|
||||||
|
|
||||||
|
config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True))
|
||||||
|
|
||||||
|
for cs in config_states:
|
||||||
|
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
||||||
|
name = cs.get("name", "Config")
|
||||||
|
full_name = f"{name}: {timestamp}"
|
||||||
|
all_config_states[full_name] = cs
|
||||||
|
|
||||||
|
return all_config_states
|
||||||
|
|
||||||
|
|
||||||
|
def get_webui_config():
|
||||||
|
webui_repo = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
|
webui_repo = git.Repo(script_path)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
webui_remote = None
|
||||||
|
webui_commit_hash = None
|
||||||
|
webui_commit_date = None
|
||||||
|
webui_branch = None
|
||||||
|
if webui_repo and not webui_repo.bare:
|
||||||
|
try:
|
||||||
|
webui_remote = next(webui_repo.remote().urls, None)
|
||||||
|
head = webui_repo.head.commit
|
||||||
|
webui_commit_date = webui_repo.head.commit.committed_date
|
||||||
|
webui_commit_hash = head.hexsha
|
||||||
|
webui_branch = webui_repo.active_branch.name
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
webui_remote = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"remote": webui_remote,
|
||||||
|
"commit_hash": webui_commit_hash,
|
||||||
|
"commit_date": webui_commit_date,
|
||||||
|
"branch": webui_branch,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_extension_config():
|
||||||
|
ext_config = {}
|
||||||
|
|
||||||
|
for ext in extensions.extensions:
|
||||||
|
entry = {
|
||||||
|
"name": ext.name,
|
||||||
|
"path": ext.path,
|
||||||
|
"enabled": ext.enabled,
|
||||||
|
"is_builtin": ext.is_builtin,
|
||||||
|
"remote": ext.remote,
|
||||||
|
"commit_hash": ext.commit_hash,
|
||||||
|
"commit_date": ext.commit_date,
|
||||||
|
"branch": ext.branch,
|
||||||
|
"have_info_from_repo": ext.have_info_from_repo
|
||||||
|
}
|
||||||
|
|
||||||
|
ext_config[ext.name] = entry
|
||||||
|
|
||||||
|
return ext_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
creation_time = datetime.now().timestamp()
|
||||||
|
webui_config = get_webui_config()
|
||||||
|
ext_config = get_extension_config()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"created_at": creation_time,
|
||||||
|
"webui": webui_config,
|
||||||
|
"extensions": ext_config
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def restore_webui_config(config):
|
||||||
|
print("* Restoring webui state...")
|
||||||
|
|
||||||
|
if "webui" not in config:
|
||||||
|
print("Error: No webui data saved to config")
|
||||||
|
return
|
||||||
|
|
||||||
|
webui_config = config["webui"]
|
||||||
|
|
||||||
|
if "commit_hash" not in webui_config:
|
||||||
|
print("Error: No commit saved to webui config")
|
||||||
|
return
|
||||||
|
|
||||||
|
webui_commit_hash = webui_config.get("commit_hash", None)
|
||||||
|
webui_repo = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
|
webui_repo = git.Repo(script_path)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
webui_repo.git.fetch(all=True)
|
||||||
|
webui_repo.git.reset(webui_commit_hash, hard=True)
|
||||||
|
print(f"* Restored webui to commit {webui_commit_hash}.")
|
||||||
|
except Exception:
|
||||||
|
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_extension_config(config):
|
||||||
|
print("* Restoring extension state...")
|
||||||
|
|
||||||
|
if "extensions" not in config:
|
||||||
|
print("Error: No extension data saved to config")
|
||||||
|
return
|
||||||
|
|
||||||
|
ext_config = config["extensions"]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
disabled = []
|
||||||
|
|
||||||
|
for ext in tqdm.tqdm(extensions.extensions):
|
||||||
|
if ext.is_builtin:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
current_commit = ext.commit_hash
|
||||||
|
|
||||||
|
if ext.name not in ext_config:
|
||||||
|
ext.disabled = True
|
||||||
|
disabled.append(ext.name)
|
||||||
|
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry = ext_config[ext.name]
|
||||||
|
|
||||||
|
if "commit_hash" in entry and entry["commit_hash"]:
|
||||||
|
try:
|
||||||
|
ext.fetch_and_reset_hard(entry["commit_hash"])
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
if current_commit != entry["commit_hash"]:
|
||||||
|
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
|
||||||
|
except Exception as ex:
|
||||||
|
results.append((ext, current_commit[:8], False, ex))
|
||||||
|
else:
|
||||||
|
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
|
||||||
|
|
||||||
|
if not entry.get("enabled", False):
|
||||||
|
ext.disabled = True
|
||||||
|
disabled.append(ext.name)
|
||||||
|
else:
|
||||||
|
ext.disabled = False
|
||||||
|
|
||||||
|
shared.opts.disabled_extensions = disabled
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
print("* Finished restoring extensions. Results:")
|
||||||
|
for ext, prev_commit, success, result in results:
|
||||||
|
if success:
|
||||||
|
print(f" + {ext.name}: {prev_commit} -> {result}")
|
||||||
|
else:
|
||||||
|
print(f" ! {ext.name}: FAILURE ({result})")
|
@ -92,14 +92,18 @@ def cond_cast_float(input):
|
|||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if device.type == 'mps':
|
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
return torch.randn(shape, device=cpu).to(device)
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape):
|
def randn_without_seed(shape):
|
||||||
if device.type == 'mps':
|
from modules.shared import opts
|
||||||
|
|
||||||
|
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
return torch.randn(shape, device=cpu).to(device)
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
@ -3,10 +3,11 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
|
|
||||||
@ -31,12 +32,15 @@ class Extension:
|
|||||||
self.status = ''
|
self.status = ''
|
||||||
self.can_update = False
|
self.can_update = False
|
||||||
self.is_builtin = is_builtin
|
self.is_builtin = is_builtin
|
||||||
|
self.commit_hash = ''
|
||||||
|
self.commit_date = None
|
||||||
self.version = ''
|
self.version = ''
|
||||||
|
self.branch = None
|
||||||
self.remote = None
|
self.remote = None
|
||||||
self.have_info_from_repo = False
|
self.have_info_from_repo = False
|
||||||
|
|
||||||
def read_info_from_repo(self):
|
def read_info_from_repo(self):
|
||||||
if self.have_info_from_repo:
|
if self.is_builtin or self.have_info_from_repo:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.have_info_from_repo = True
|
self.have_info_from_repo = True
|
||||||
@ -56,10 +60,15 @@ class Extension:
|
|||||||
self.status = 'unknown'
|
self.status = 'unknown'
|
||||||
self.remote = next(repo.remote().urls, None)
|
self.remote = next(repo.remote().urls, None)
|
||||||
head = repo.head.commit
|
head = repo.head.commit
|
||||||
ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
|
self.commit_date = repo.head.commit.committed_date
|
||||||
self.version = f'{head.hexsha[:8]} ({ts})'
|
ts = time.asctime(time.gmtime(self.commit_date))
|
||||||
|
if repo.active_branch:
|
||||||
|
self.branch = repo.active_branch.name
|
||||||
|
self.commit_hash = head.hexsha
|
||||||
|
self.version = f'{self.commit_hash[:8]} ({ts})'
|
||||||
|
|
||||||
except Exception:
|
except Exception as ex:
|
||||||
|
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
||||||
self.remote = None
|
self.remote = None
|
||||||
|
|
||||||
def list_files(self, subdir, extension):
|
def list_files(self, subdir, extension):
|
||||||
@ -82,18 +91,30 @@ class Extension:
|
|||||||
for fetch in repo.remote().fetch(dry_run=True):
|
for fetch in repo.remote().fetch(dry_run=True):
|
||||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||||
self.can_update = True
|
self.can_update = True
|
||||||
self.status = "behind"
|
self.status = "new commits"
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
origin = repo.rev_parse('origin')
|
||||||
|
if repo.head.commit != origin:
|
||||||
|
self.can_update = True
|
||||||
|
self.status = "behind HEAD"
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
self.can_update = False
|
||||||
|
self.status = "unknown (remote error)"
|
||||||
|
return
|
||||||
|
|
||||||
self.can_update = False
|
self.can_update = False
|
||||||
self.status = "latest"
|
self.status = "latest"
|
||||||
|
|
||||||
def fetch_and_reset_hard(self):
|
def fetch_and_reset_hard(self, commit='origin'):
|
||||||
repo = git.Repo(self.path)
|
repo = git.Repo(self.path)
|
||||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||||
repo.git.fetch(all=True)
|
repo.git.fetch(all=True)
|
||||||
repo.git.reset('origin', hard=True)
|
repo.git.reset(commit, hard=True)
|
||||||
|
self.have_info_from_repo = False
|
||||||
|
|
||||||
|
|
||||||
def list_extensions():
|
def list_extensions():
|
||||||
|
@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
|||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_hypernetwork
|
additional = shared.opts.sd_hypernetwork
|
||||||
|
|
||||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -71,7 +72,7 @@ def to_half(tensor, enable):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
@ -241,13 +242,54 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
shared.state.textinfo = "Saving"
|
shared.state.textinfo = "Saving"
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
|
metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}
|
||||||
|
|
||||||
|
if save_metadata:
|
||||||
|
merge_recipe = {
|
||||||
|
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||||
|
"primary_model_hash": primary_model_info.sha256,
|
||||||
|
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
||||||
|
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
||||||
|
"interp_method": interp_method,
|
||||||
|
"multiplier": multiplier,
|
||||||
|
"save_as_half": save_as_half,
|
||||||
|
"custom_name": custom_name,
|
||||||
|
"config_source": config_source,
|
||||||
|
"bake_in_vae": bake_in_vae,
|
||||||
|
"discard_weights": discard_weights,
|
||||||
|
"is_inpainting": result_is_inpainting_model,
|
||||||
|
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||||
|
}
|
||||||
|
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||||
|
|
||||||
|
def add_model_metadata(checkpoint_info):
|
||||||
|
checkpoint_info.calculate_shorthash()
|
||||||
|
metadata["sd_merge_models"][checkpoint_info.sha256] = {
|
||||||
|
"name": checkpoint_info.name,
|
||||||
|
"legacy_hash": checkpoint_info.hash,
|
||||||
|
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
||||||
|
|
||||||
|
add_model_metadata(primary_model_info)
|
||||||
|
if secondary_model_info:
|
||||||
|
add_model_metadata(secondary_model_info)
|
||||||
|
if tertiary_model_info:
|
||||||
|
add_model_metadata(tertiary_model_info)
|
||||||
|
|
||||||
|
metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(theta_0, output_modelname)
|
torch.save(theta_0, output_modelname)
|
||||||
|
|
||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
||||||
|
if created_model:
|
||||||
|
created_model.calculate_shorthash()
|
||||||
|
|
||||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||||
|
|
||||||
|
@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
|
|
||||||
restore_old_hires_fix_params(res)
|
restore_old_hires_fix_params(res)
|
||||||
|
|
||||||
|
# Missing RNG means the default was set, which is GPU RNG
|
||||||
|
if "RNG" not in res:
|
||||||
|
res["RNG"] = "GPU"
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -304,6 +308,8 @@ infotext_to_setting_name_mapping = [
|
|||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
('UniPC order', 'uni_pc_order'),
|
('UniPC order', 'uni_pc_order'),
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
|
('RNG', 'randn_source'),
|
||||||
|
('NGMS', 's_min_uncond'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -318,6 +318,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
|||||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||||
max_filename_part_length = 128
|
max_filename_part_length = 128
|
||||||
|
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
||||||
|
|
||||||
|
|
||||||
def sanitize_filename_part(text, replace_spaces=True):
|
def sanitize_filename_part(text, replace_spaces=True):
|
||||||
@ -352,6 +353,11 @@ class FilenameGenerator:
|
|||||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||||
'prompt_words': lambda self: self.prompt_words(),
|
'prompt_words': lambda self: self.prompt_words(),
|
||||||
|
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
|
||||||
|
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
||||||
|
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||||
|
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||||
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
@ -360,6 +366,22 @@ class FilenameGenerator:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
|
def hasprompt(self, *args):
|
||||||
|
lower = self.prompt.lower()
|
||||||
|
if self.p is None or self.prompt is None:
|
||||||
|
return None
|
||||||
|
outres = ""
|
||||||
|
for arg in args:
|
||||||
|
if arg != "":
|
||||||
|
division = arg.split("|")
|
||||||
|
expected = division[0].lower()
|
||||||
|
default = division[1] if len(division) > 1 else ""
|
||||||
|
if lower.find(expected) >= 0:
|
||||||
|
outres = f'{outres}{expected}'
|
||||||
|
else:
|
||||||
|
outres = outres if default == "" else f'{outres}{default}'
|
||||||
|
return sanitize_filename_part(outres)
|
||||||
|
|
||||||
def prompt_no_style(self):
|
def prompt_no_style(self):
|
||||||
if self.p is None or self.prompt is None:
|
if self.p is None or self.prompt is None:
|
||||||
@ -403,9 +425,9 @@ class FilenameGenerator:
|
|||||||
|
|
||||||
for m in re_pattern.finditer(x):
|
for m in re_pattern.finditer(x):
|
||||||
text, pattern = m.groups()
|
text, pattern = m.groups()
|
||||||
res += text
|
|
||||||
|
|
||||||
if pattern is None:
|
if pattern is None:
|
||||||
|
res += text
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pattern_args = []
|
pattern_args = []
|
||||||
@ -426,11 +448,13 @@ class FilenameGenerator:
|
|||||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
if replacement is not None:
|
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
||||||
res += str(replacement)
|
continue
|
||||||
|
elif replacement is not None:
|
||||||
|
res += text + str(replacement)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
res += f'[{pattern}]'
|
res += f'{text}[{pattern}]'
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||||
|
|
||||||
from modules import devices, sd_samplers
|
from modules import devices, sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
@ -46,7 +46,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
img = Image.open(image)
|
try:
|
||||||
|
img = Image.open(image)
|
||||||
|
except UnidentifiedImageError:
|
||||||
|
continue
|
||||||
# Use the EXIF orientation of photos taken by smartphones.
|
# Use the EXIF orientation of photos taken by smartphones.
|
||||||
img = ImageOps.exif_transpose(img)
|
img = ImageOps.exif_transpose(img)
|
||||||
p.init_images = [img] * p.batch_size
|
p.init_images = [img] * p.batch_size
|
||||||
@ -78,7 +81,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
processed_image.save(os.path.join(output_dir, filename))
|
processed_image.save(os.path.join(output_dir, filename))
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -114,6 +117,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
if image is not None:
|
if image is not None:
|
||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
|
|
||||||
|
if selected_scale_tab == 1:
|
||||||
|
assert image, "Can't scale by because no image is selected"
|
||||||
|
|
||||||
|
width = int(image.width * scale_by)
|
||||||
|
height = int(image.height * scale_by)
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
|
||||||
p = StableDiffusionProcessingImg2Img(
|
p = StableDiffusionProcessingImg2Img(
|
||||||
@ -151,7 +160,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
override_settings=override_settings,
|
override_settings=override_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_img2img
|
||||||
p.script_args = args
|
p.script_args = args
|
||||||
|
|
||||||
if shared.cmd_opts.enable_console_prompts:
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
|
@ -32,7 +32,7 @@ def download_default_clip_interrogate_categories(content_dir):
|
|||||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(tmpdir)
|
os.makedirs(tmpdir, exist_ok=True)
|
||||||
for category_type in category_types:
|
for category_type in category_types:
|
||||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||||
os.rename(tmpdir, content_dir)
|
os.rename(tmpdir, content_dir)
|
||||||
@ -41,7 +41,7 @@ def download_default_clip_interrogate_categories(content_dir):
|
|||||||
errors.display(e, "downloading default CLIP interrogate categories")
|
errors.display(e, "downloading default CLIP interrogate categories")
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(tmpdir):
|
if os.path.exists(tmpdir):
|
||||||
os.remove(tmpdir)
|
os.removedirs(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
class InterrogateModels:
|
class InterrogateModels:
|
||||||
|
@ -13,6 +13,18 @@ def connect(token, port, region):
|
|||||||
config = conf.PyngrokConfig(
|
config = conf.PyngrokConfig(
|
||||||
auth_token=token, region=region
|
auth_token=token, region=region
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Guard for existing tunnels
|
||||||
|
existing = ngrok.get_tunnels(pyngrok_config=config)
|
||||||
|
if existing:
|
||||||
|
for established in existing:
|
||||||
|
# Extra configuration in the case that the user is also using ngrok for other tunnels
|
||||||
|
if established.config['addr'][-4:] == str(port):
|
||||||
|
public_url = existing[0].public_url
|
||||||
|
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
|
||||||
|
'You can use this link after the launch is complete.')
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if account is None:
|
if account is None:
|
||||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||||
|
@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir
|
|||||||
models_path = os.path.join(data_path, "models")
|
models_path = os.path.join(data_path, "models")
|
||||||
extensions_dir = os.path.join(data_path, "extensions")
|
extensions_dir = os.path.join(data_path, "extensions")
|
||||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||||
|
config_states_dir = os.path.join(script_path, "config_states")
|
||||||
|
@ -18,9 +18,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
image = Image.open(img)
|
if isinstance(img, Image.Image):
|
||||||
|
image = img
|
||||||
|
fn = ''
|
||||||
|
else:
|
||||||
|
image = Image.open(os.path.abspath(img.name))
|
||||||
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
image_data.append(image)
|
image_data.append(image)
|
||||||
image_names.append(os.path.splitext(img.orig_name)[0])
|
image_names.append(fn)
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
assert input_dir, 'input directory not selected'
|
assert input_dir, 'input directory not selected'
|
||||||
|
@ -3,6 +3,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
import hashlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -105,7 +106,7 @@ class StableDiffusionProcessing:
|
|||||||
"""
|
"""
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
"""
|
"""
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||||
if sampler_index is not None:
|
if sampler_index is not None:
|
||||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||||
|
|
||||||
@ -140,6 +141,7 @@ class StableDiffusionProcessing:
|
|||||||
self.denoising_strength: float = denoising_strength
|
self.denoising_strength: float = denoising_strength
|
||||||
self.sampler_noise_scheduler_override = None
|
self.sampler_noise_scheduler_override = None
|
||||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||||
|
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
||||||
self.s_churn = s_churn or opts.s_churn
|
self.s_churn = s_churn or opts.s_churn
|
||||||
self.s_tmin = s_tmin or opts.s_tmin
|
self.s_tmin = s_tmin or opts.s_tmin
|
||||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||||
@ -162,6 +164,8 @@ class StableDiffusionProcessing:
|
|||||||
self.all_seeds = None
|
self.all_seeds = None
|
||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
self.iteration = 0
|
self.iteration = 0
|
||||||
|
self.is_hr_pass = False
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
@ -476,6 +480,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
||||||
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
generation_params.update(p.extra_generation_params)
|
||||||
@ -491,6 +498,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
|
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||||
|
p.override_settings.pop('sd_model_checkpoint', None)
|
||||||
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
for k, v in p.override_settings.items():
|
for k, v in p.override_settings.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
|
|
||||||
@ -507,8 +519,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if p.override_settings_restore_afterwards:
|
if p.override_settings_restore_afterwards:
|
||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
if k == 'sd_model_checkpoint':
|
|
||||||
sd_models.reload_model_weights()
|
|
||||||
|
|
||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
@ -639,8 +649,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
step_multiplier = 1
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
if not shared.opts.dont_fix_second_order_samplers_schedule:
|
||||||
|
try:
|
||||||
|
step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||||
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
@ -670,6 +686,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
|
p.batch_index = i
|
||||||
|
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
@ -706,9 +724,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
image.info["parameters"] = text
|
image.info["parameters"] = text
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
||||||
image_mask = p.mask_for_overlay.convert('RGB')
|
image_mask = p.mask_for_overlay.convert('RGB')
|
||||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
|
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||||
|
|
||||||
if opts.save_mask:
|
if opts.save_mask:
|
||||||
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
||||||
@ -718,7 +736,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
if opts.return_mask:
|
if opts.return_mask:
|
||||||
output_images.append(image_mask)
|
output_images.append(image_mask)
|
||||||
|
|
||||||
if opts.return_mask_composite:
|
if opts.return_mask_composite:
|
||||||
output_images.append(image_mask_composite)
|
output_images.append(image_mask_composite)
|
||||||
|
|
||||||
@ -871,6 +889,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = self.hr_upscale_to_y
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
@ -940,6 +960,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
@ -1007,6 +1029,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
self.color_corrections = []
|
self.color_corrections = []
|
||||||
imgs = []
|
imgs = []
|
||||||
for img in self.init_images:
|
for img in self.init_images:
|
||||||
|
|
||||||
|
# Save init image
|
||||||
|
if opts.save_init_img:
|
||||||
|
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||||
|
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
||||||
|
|
||||||
image = images.flatten(img, opts.img2img_background_color)
|
image = images.flatten(img, opts.img2img_background_color)
|
||||||
|
|
||||||
if crop_region is None and self.resize_mode != 3:
|
if crop_region is None and self.resize_mode != 3:
|
||||||
|
@ -13,6 +13,8 @@ import modules.shared as shared
|
|||||||
current_task = None
|
current_task = None
|
||||||
pending_tasks = {}
|
pending_tasks = {}
|
||||||
finished_tasks = []
|
finished_tasks = []
|
||||||
|
recorded_results = []
|
||||||
|
recorded_results_limit = 2
|
||||||
|
|
||||||
|
|
||||||
def start_task(id_task):
|
def start_task(id_task):
|
||||||
@ -33,6 +35,12 @@ def finish_task(id_task):
|
|||||||
finished_tasks.pop(0)
|
finished_tasks.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def record_results(id_task, res):
|
||||||
|
recorded_results.append((id_task, res))
|
||||||
|
if len(recorded_results) > recorded_results_limit:
|
||||||
|
recorded_results.pop(0)
|
||||||
|
|
||||||
|
|
||||||
def add_task_to_queue(id_job):
|
def add_task_to_queue(id_job):
|
||||||
pending_tasks[id_job] = time.time()
|
pending_tasks[id_job] = time.time()
|
||||||
|
|
||||||
@ -97,3 +105,13 @@ def progressapi(req: ProgressRequest):
|
|||||||
|
|
||||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_progress(id_task):
|
||||||
|
while id_task == current_task or id_task in pending_tasks:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
|
||||||
|
if res is not None:
|
||||||
|
return res
|
||||||
|
|
||||||
|
return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
|
||||||
|
@ -9,7 +9,7 @@ from realesrgan import RealESRGANer
|
|||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts
|
||||||
|
from modules import modelloader
|
||||||
|
|
||||||
class UpscalerRealESRGAN(Upscaler):
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
@ -23,7 +23,15 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
self.enable = True
|
self.enable = True
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
scalers = self.load_models(path)
|
scalers = self.load_models(path)
|
||||||
|
|
||||||
|
local_model_paths = self.find_models(ext_filter=[".pth"])
|
||||||
for scaler in scalers:
|
for scaler in scalers:
|
||||||
|
if scaler.local_data_path.startswith("http"):
|
||||||
|
filename = modelloader.friendly_name(scaler.local_data_path)
|
||||||
|
local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None)
|
||||||
|
if local:
|
||||||
|
scaler.local_data_path = local
|
||||||
|
|
||||||
if scaler.name in opts.realesrgan_enabled_models:
|
if scaler.name in opts.realesrgan_enabled_models:
|
||||||
self.scalers.append(scaler)
|
self.scalers.append(scaler)
|
||||||
|
|
||||||
@ -64,7 +72,9 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
print(f"Unable to find model info: {path}")
|
print(f"Unable to find model info: {path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
if info.local_data_path.startswith("http"):
|
||||||
|
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||||
|
|
||||||
return info
|
return info
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# this code is adapted from the script contributed by anon from /h/
|
# this code is adapted from the script contributed by anon from /h/
|
||||||
|
|
||||||
import io
|
|
||||||
import pickle
|
import pickle
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
@ -12,11 +11,9 @@ import _codecs
|
|||||||
import zipfile
|
import zipfile
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
|
||||||
|
|
||||||
def encode(*args):
|
def encode(*args):
|
||||||
out = _codecs.encode(*args)
|
out = _codecs.encode(*args)
|
||||||
return out
|
return out
|
||||||
@ -27,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||||||
|
|
||||||
def persistent_load(self, saved_id):
|
def persistent_load(self, saved_id):
|
||||||
assert saved_id[0] == 'storage'
|
assert saved_id[0] == 'storage'
|
||||||
return TypedStorage()
|
|
||||||
|
try:
|
||||||
|
return TypedStorage(_internal=True)
|
||||||
|
except TypeError:
|
||||||
|
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
||||||
|
|
||||||
def find_class(self, module, name):
|
def find_class(self, module, name):
|
||||||
if self.extra_handler is not None:
|
if self.extra_handler is not None:
|
||||||
|
@ -93,6 +93,7 @@ callback_map = dict(
|
|||||||
callbacks_infotext_pasted=[],
|
callbacks_infotext_pasted=[],
|
||||||
callbacks_script_unloaded=[],
|
callbacks_script_unloaded=[],
|
||||||
callbacks_before_ui=[],
|
callbacks_before_ui=[],
|
||||||
|
callbacks_on_reload=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -109,6 +110,14 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
|||||||
report_exception(c, 'app_started_callback')
|
report_exception(c, 'app_started_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def app_reload_callback():
|
||||||
|
for c in callback_map['callbacks_on_reload']:
|
||||||
|
try:
|
||||||
|
c.callback()
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'callbacks_on_reload')
|
||||||
|
|
||||||
|
|
||||||
def model_loaded_callback(sd_model):
|
def model_loaded_callback(sd_model):
|
||||||
for c in callback_map['callbacks_model_loaded']:
|
for c in callback_map['callbacks_model_loaded']:
|
||||||
try:
|
try:
|
||||||
@ -254,6 +263,11 @@ def on_app_started(callback):
|
|||||||
add_callback(callback_map['callbacks_app_started'], callback)
|
add_callback(callback_map['callbacks_app_started'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_before_reload(callback):
|
||||||
|
"""register a function to be called just before the server reloads."""
|
||||||
|
add_callback(callback_map['callbacks_on_reload'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_model_loaded(callback):
|
def on_model_loaded(callback):
|
||||||
"""register a function to be called when the stable diffusion model is created; the model is
|
"""register a function to be called when the stable diffusion model is created; the model is
|
||||||
passed as an argument; this function is also called when the script is reloaded. """
|
passed as an argument; this function is also called when the script is reloaded. """
|
||||||
|
@ -2,6 +2,8 @@ import collections
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
import gc
|
||||||
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -52,6 +54,15 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
_, ext = os.path.splitext(self.filename)
|
||||||
|
if ext.lower() == ".safetensors":
|
||||||
|
try:
|
||||||
|
self.metadata = read_metadata_from_safetensors(filename)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading checkpoint metadata: {filename}")
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
for id in self.ids:
|
for id in self.ids:
|
||||||
@ -395,13 +406,39 @@ def repair_config(sd_config):
|
|||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
|
||||||
|
class SdModelData:
|
||||||
|
def __init__(self):
|
||||||
|
self.sd_model = None
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def get_sd_model(self):
|
||||||
|
if self.sd_model is None:
|
||||||
|
with self.lock:
|
||||||
|
try:
|
||||||
|
load_model()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "loading stable diffusion model")
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("Stable diffusion model failed to load", file=sys.stderr)
|
||||||
|
self.sd_model = None
|
||||||
|
|
||||||
|
return self.sd_model
|
||||||
|
|
||||||
|
def set_sd_model(self, v):
|
||||||
|
self.sd_model = v
|
||||||
|
|
||||||
|
|
||||||
|
model_data = SdModelData()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
if shared.sd_model:
|
if model_data.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||||
shared.sd_model = None
|
model_data.sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
@ -455,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
shared.sd_model = sd_model
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
|
|
||||||
@ -475,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = model_data.sd_model
|
||||||
|
|
||||||
if sd_model is None: # previous model load failed
|
if sd_model is None: # previous model load failed
|
||||||
current_checkpoint_info = None
|
current_checkpoint_info = None
|
||||||
@ -503,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
checkpoints_loaded.clear()
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return shared.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
@ -526,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import lowvram, devices, sd_hijack
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
if shared.sd_model:
|
if model_data.sd_model:
|
||||||
|
model_data.sd_model.to(devices.cpu)
|
||||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
model_data.sd_model = None
|
||||||
shared.sd_model.to(devices.cpu)
|
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
|
||||||
shared.sd_model = None
|
|
||||||
sd_model = None
|
sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
@ -544,4 +579,4 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
@ -60,3 +60,13 @@ def store_latent(decoded):
|
|||||||
|
|
||||||
class InterruptedException(BaseException):
|
class InterruptedException(BaseException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if opts.randn_source == "CPU":
|
||||||
|
import torchsde._brownian.brownian_interval
|
||||||
|
|
||||||
|
def torchsde_randn(size, dtype, device, seed):
|
||||||
|
generator = torch.Generator(devices.cpu).manual_seed(int(seed))
|
||||||
|
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
|
||||||
|
|
||||||
|
torchsde._brownian.brownian_interval._randn = torchsde_randn
|
||||||
|
@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
|
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
raise sd_samplers_common.InterruptedException
|
raise sd_samplers_common.InterruptedException
|
||||||
|
|
||||||
@ -115,12 +115,21 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
sigma_in = denoiser_params.sigma
|
sigma_in = denoiser_params.sigma
|
||||||
tensor = denoiser_params.text_cond
|
tensor = denoiser_params.text_cond
|
||||||
uncond = denoiser_params.text_uncond
|
uncond = denoiser_params.text_uncond
|
||||||
|
skip_uncond = False
|
||||||
|
|
||||||
if tensor.shape[1] == uncond.shape[1]:
|
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||||
if not is_edit_model:
|
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||||
cond_in = torch.cat([tensor, uncond])
|
skip_uncond = True
|
||||||
else:
|
x_in = x_in[:-batch_size]
|
||||||
|
sigma_in = sigma_in[:-batch_size]
|
||||||
|
|
||||||
|
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||||
|
if is_edit_model:
|
||||||
cond_in = torch.cat([tensor, uncond, uncond])
|
cond_in = torch.cat([tensor, uncond, uncond])
|
||||||
|
elif skip_uncond:
|
||||||
|
cond_in = tensor
|
||||||
|
else:
|
||||||
|
cond_in = torch.cat([tensor, uncond])
|
||||||
|
|
||||||
if shared.batch_cond_uncond:
|
if shared.batch_cond_uncond:
|
||||||
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
|
x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
|
||||||
@ -144,7 +153,13 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
|
|
||||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||||
|
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
if not skip_uncond:
|
||||||
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||||
|
|
||||||
|
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||||
|
if skip_uncond:
|
||||||
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||||
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||||
|
|
||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
||||||
cfg_denoised_callback(denoised_params)
|
cfg_denoised_callback(denoised_params)
|
||||||
@ -152,20 +167,21 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
devices.test_for_nans(x_out, "unet")
|
devices.test_for_nans(x_out, "unet")
|
||||||
|
|
||||||
if opts.live_preview_content == "Prompt":
|
if opts.live_preview_content == "Prompt":
|
||||||
sd_samplers_common.store_latent(x_out[0:uncond.shape[0]])
|
sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
|
||||||
elif opts.live_preview_content == "Negative prompt":
|
elif opts.live_preview_content == "Negative prompt":
|
||||||
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
|
||||||
|
|
||||||
if not is_edit_model:
|
if is_edit_model:
|
||||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
|
||||||
else:
|
|
||||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||||
|
elif skip_uncond:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||||
|
else:
|
||||||
|
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||||
|
|
||||||
if self.mask is not None:
|
if self.mask is not None:
|
||||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||||
|
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
return denoised
|
return denoised
|
||||||
|
|
||||||
|
|
||||||
@ -190,7 +206,7 @@ class TorchHijack:
|
|||||||
if noise.shape == x.shape:
|
if noise.shape == x.shape:
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
if x.device.type == 'mps':
|
if opts.randn_source == "CPU" or x.device.type == 'mps':
|
||||||
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
return torch.randn_like(x, device=devices.cpu).to(x.device)
|
||||||
else:
|
else:
|
||||||
return torch.randn_like(x)
|
return torch.randn_like(x)
|
||||||
@ -210,6 +226,7 @@ class KDiffusionSampler:
|
|||||||
self.eta = None
|
self.eta = None
|
||||||
self.config = None
|
self.config = None
|
||||||
self.last_latent = None
|
self.last_latent = None
|
||||||
|
self.s_min_uncond = None
|
||||||
|
|
||||||
self.conditioning_key = sd_model.model.conditioning_key
|
self.conditioning_key = sd_model.model.conditioning_key
|
||||||
|
|
||||||
@ -244,6 +261,7 @@ class KDiffusionSampler:
|
|||||||
self.model_wrap_cfg.step = 0
|
self.model_wrap_cfg.step = 0
|
||||||
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
|
||||||
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
|
||||||
|
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
|
||||||
|
|
||||||
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
|
||||||
|
|
||||||
@ -326,6 +344,7 @@ class KDiffusionSampler:
|
|||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale,
|
'cond_scale': p.cfg_scale,
|
||||||
|
's_min_uncond': self.s_min_uncond
|
||||||
}
|
}
|
||||||
|
|
||||||
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
@ -359,7 +378,8 @@ class KDiffusionSampler:
|
|||||||
'cond': conditioning,
|
'cond': conditioning,
|
||||||
'image_cond': image_conditioning,
|
'image_cond': image_conditioning,
|
||||||
'uncond': unconditional_conditioning,
|
'uncond': unconditional_conditioning,
|
||||||
'cond_scale': p.cfg_scale
|
'cond_scale': p.cfg_scale,
|
||||||
|
's_min_uncond': self.s_min_uncond
|
||||||
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
@ -4,6 +4,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
import requests
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -15,6 +16,7 @@ import modules.styles
|
|||||||
import modules.devices as devices
|
import modules.devices as devices
|
||||||
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
|
||||||
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
|
||||||
|
from ldm.models.diffusion.ddpm import LatentDiffusion
|
||||||
|
|
||||||
demo = None
|
demo = None
|
||||||
|
|
||||||
@ -39,6 +41,7 @@ restricted_opts = {
|
|||||||
"outdir_grids",
|
"outdir_grids",
|
||||||
"outdir_txt2img_grids",
|
"outdir_txt2img_grids",
|
||||||
"outdir_save",
|
"outdir_save",
|
||||||
|
"outdir_init_images"
|
||||||
}
|
}
|
||||||
|
|
||||||
ui_reorder_categories = [
|
ui_reorder_categories = [
|
||||||
@ -54,6 +57,21 @@ ui_reorder_categories = [
|
|||||||
"scripts",
|
"scripts",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
|
||||||
|
gradio_hf_hub_themes = [
|
||||||
|
"gradio/glass",
|
||||||
|
"gradio/monochrome",
|
||||||
|
"gradio/seafoam",
|
||||||
|
"gradio/soft",
|
||||||
|
"freddyaboulton/dracula_revamped",
|
||||||
|
"gradio/dracula_test",
|
||||||
|
"abidlabs/dracula_test",
|
||||||
|
"abidlabs/pakistan",
|
||||||
|
"dawood/microsoft_windows",
|
||||||
|
"ysharma/steampunk"
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
|
||||||
|
|
||||||
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
|
||||||
@ -252,7 +270,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
|
|||||||
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
"use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"),
|
||||||
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
"use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"),
|
||||||
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
"save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
|
||||||
"do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),
|
"save_init_img": OptionInfo(False, "Save init images when using img2img"),
|
||||||
|
|
||||||
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
"temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"),
|
||||||
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
"clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),
|
||||||
@ -268,6 +286,7 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), {
|
|||||||
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
|
||||||
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
|
||||||
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
|
||||||
|
"outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
|
||||||
@ -283,6 +302,8 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
|
|||||||
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
|
||||||
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}),
|
||||||
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
|
||||||
|
"SCUNET_tile": OptionInfo(256, "Tile size for SCUNET upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
|
||||||
|
"SCUNET_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SCUNET upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
options_templates.update(options_section(('face-restoration', "Face restoration"), {
|
||||||
@ -331,6 +352,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
|
|||||||
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
"comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
|
||||||
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
|
||||||
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
|
||||||
|
"randn_source": OptionInfo("GPU", "Random number generator source. Changes seeds drastically. Use CPU to produce the same picture across different vidocard vendors.", gr.Radio, {"choices": ["GPU", "CPU"]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
options_templates.update(options_section(('compatibility', "Compatibility"), {
|
||||||
@ -338,6 +360,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
|
|||||||
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
"use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."),
|
||||||
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
"no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."),
|
||||||
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
|
||||||
|
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
options_templates.update(options_section(('interrogate', "Interrogate Options"), {
|
||||||
@ -361,7 +384,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), {
|
|||||||
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
"extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks (px)"),
|
||||||
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
"extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks (px)"),
|
||||||
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
"extra_networks_add_text_separator": OptionInfo(" ", "Extra text to add before <...> when adding extra network to prompt"),
|
||||||
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
"sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "User interface"), {
|
options_templates.update(options_section(('ui', "User interface"), {
|
||||||
@ -377,16 +400,20 @@ options_templates.update(options_section(('ui', "User interface"), {
|
|||||||
"font": OptionInfo("", "Font for image grids that have text"),
|
"font": OptionInfo("", "Font for image grids that have text"),
|
||||||
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
|
||||||
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
|
||||||
|
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
|
||||||
|
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
|
||||||
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
|
||||||
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
|
||||||
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
|
||||||
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
|
||||||
|
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
|
||||||
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
|
||||||
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
|
||||||
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
|
||||||
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
|
||||||
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
"localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)),
|
||||||
|
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
options_templates.update(options_section(('ui', "Live previews"), {
|
options_templates.update(options_section(('ui', "Live previews"), {
|
||||||
@ -405,6 +432,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
|
|||||||
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
"eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
"ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
|
||||||
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
|
's_min_uncond': OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}),
|
||||||
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
|
||||||
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
|
||||||
@ -424,6 +452,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {
|
|||||||
options_templates.update(options_section((None, "Hidden options"), {
|
options_templates.update(options_section((None, "Hidden options"), {
|
||||||
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
"disabled_extensions": OptionInfo([], "Disable these extensions"),
|
||||||
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
"disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}),
|
||||||
|
"restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"),
|
||||||
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
"sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@ -574,13 +603,37 @@ class Options:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
opts = Options()
|
opts = Options()
|
||||||
if os.path.exists(config_filename):
|
if os.path.exists(config_filename):
|
||||||
opts.load(config_filename)
|
opts.load(config_filename)
|
||||||
|
|
||||||
|
|
||||||
|
class Shared(sys.modules[__name__].__class__):
|
||||||
|
"""
|
||||||
|
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
|
||||||
|
at program startup.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sd_model_val = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sd_model(self):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
return modules.sd_models.model_data.get_sd_model()
|
||||||
|
|
||||||
|
@sd_model.setter
|
||||||
|
def sd_model(self, value):
|
||||||
|
import modules.sd_models
|
||||||
|
|
||||||
|
modules.sd_models.model_data.set_sd_model(value)
|
||||||
|
|
||||||
|
|
||||||
|
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
|
||||||
|
sys.modules[__name__].__class__ = Shared
|
||||||
|
|
||||||
settings_components = None
|
settings_components = None
|
||||||
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
|
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
|
||||||
|
|
||||||
latent_upscale_default_mode = "Latent"
|
latent_upscale_default_mode = "Latent"
|
||||||
latent_upscale_modes = {
|
latent_upscale_modes = {
|
||||||
@ -594,12 +647,28 @@ latent_upscale_modes = {
|
|||||||
|
|
||||||
sd_upscalers = []
|
sd_upscalers = []
|
||||||
|
|
||||||
sd_model = None
|
|
||||||
|
|
||||||
clip_model = None
|
clip_model = None
|
||||||
|
|
||||||
progress_print_out = sys.stdout
|
progress_print_out = sys.stdout
|
||||||
|
|
||||||
|
gradio_theme = gr.themes.Base()
|
||||||
|
|
||||||
|
|
||||||
|
def reload_gradio_theme(theme_name=None):
|
||||||
|
global gradio_theme
|
||||||
|
if not theme_name:
|
||||||
|
theme_name = opts.gradio_theme
|
||||||
|
|
||||||
|
if theme_name == "Default":
|
||||||
|
gradio_theme = gr.themes.Default()
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
print("Can't access HuggingFace Hub, falling back to default Gradio theme")
|
||||||
|
gradio_theme = gr.themes.Default()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TotalTQDM:
|
class TotalTQDM:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -72,16 +72,14 @@ class StyleDatabase:
|
|||||||
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])
|
||||||
|
|
||||||
def save_styles(self, path: str) -> None:
|
def save_styles(self, path: str) -> None:
|
||||||
# Write to temporary file first, so we don't nuke the file if something goes wrong
|
# Always keep a backup file around
|
||||||
fd, temp_path = tempfile.mkstemp(".csv")
|
if os.path.exists(path):
|
||||||
|
shutil.copy(path, path + ".bak")
|
||||||
|
|
||||||
|
fd = os.open(path, os.O_RDWR|os.O_CREAT)
|
||||||
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
|
||||||
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
|
||||||
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
|
||||||
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
|
||||||
writer.writeheader()
|
writer.writeheader()
|
||||||
writer.writerows(style._asdict() for k, style in self.styles.items())
|
writer.writerows(style._asdict() for k, style in self.styles.items())
|
||||||
|
|
||||||
# Always keep a backup file around
|
|
||||||
if os.path.exists(path):
|
|
||||||
shutil.move(path, path + ".bak")
|
|
||||||
shutil.move(temp_path, path)
|
|
||||||
|
@ -11,7 +11,7 @@ from modules.shared import opts, cmd_opts
|
|||||||
from modules.textual_inversion import autocrop
|
from modules.textual_inversion import autocrop
|
||||||
|
|
||||||
|
|
||||||
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||||
try:
|
try:
|
||||||
if process_caption:
|
if process_caption:
|
||||||
shared.interrogator.load()
|
shared.interrogator.load()
|
||||||
@ -19,7 +19,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height,
|
|||||||
if process_caption_deepbooru:
|
if process_caption_deepbooru:
|
||||||
deepbooru.model.start()
|
deepbooru.model.start()
|
||||||
|
|
||||||
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, thr
|
|||||||
return wh and center_crop(image, *wh)
|
return wh and center_crop(image, *wh)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_keep_original_size, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None):
|
||||||
width = process_width
|
width = process_width
|
||||||
height = process_height
|
height = process_height
|
||||||
src = os.path.abspath(process_src)
|
src = os.path.abspath(process_src)
|
||||||
@ -161,7 +161,9 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||||||
params.subindex = 0
|
params.subindex = 0
|
||||||
filename = os.path.join(src, imagefile)
|
filename = os.path.join(src, imagefile)
|
||||||
try:
|
try:
|
||||||
img = Image.open(filename).convert("RGB")
|
img = Image.open(filename)
|
||||||
|
img = ImageOps.exif_transpose(img)
|
||||||
|
img = img.convert("RGB")
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -223,6 +225,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
|
|||||||
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)")
|
||||||
process_default_resize = False
|
process_default_resize = False
|
||||||
|
|
||||||
|
if process_keep_original_size:
|
||||||
|
save_pic(img, index, params, existing_caption=existing_caption)
|
||||||
|
process_default_resize = False
|
||||||
|
|
||||||
if process_default_resize:
|
if process_default_resize:
|
||||||
img = images.resize_image(1, img, width, height)
|
img = images.resize_image(1, img, width, height)
|
||||||
save_pic(img, index, params, existing_caption=existing_caption)
|
save_pic(img, index, params, existing_caption=existing_caption)
|
||||||
|
@ -233,6 +233,12 @@ class EmbeddingDatabase:
|
|||||||
self.load_from_dir(embdir)
|
self.load_from_dir(embdir)
|
||||||
embdir.update()
|
embdir.update()
|
||||||
|
|
||||||
|
# re-sort word_embeddings because load_from_dir may not load in alphabetic order.
|
||||||
|
# using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it.
|
||||||
|
sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())}
|
||||||
|
self.word_embeddings.clear()
|
||||||
|
self.word_embeddings.update(sorted_word_embeddings)
|
||||||
|
|
||||||
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
|
||||||
if self.previously_displayed_embeddings != displayed_embeddings:
|
if self.previously_displayed_embeddings != displayed_embeddings:
|
||||||
self.previously_displayed_embeddings = displayed_embeddings
|
self.previously_displayed_embeddings = displayed_embeddings
|
||||||
|
160
modules/ui.py
160
modules/ui.py
@ -19,7 +19,7 @@ import numpy as np
|
|||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
|
||||||
|
|
||||||
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
|
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing, progress
|
||||||
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
from modules.ui_components import FormRow, FormColumn, FormGroup, ToolButton, FormHTML
|
||||||
from modules.paths import script_path, data_path
|
from modules.paths import script_path, data_path
|
||||||
|
|
||||||
@ -81,6 +81,7 @@ apply_style_symbol = '\U0001f4cb' # 📋
|
|||||||
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
clear_prompt_symbol = '\U0001f5d1\ufe0f' # 🗑️
|
||||||
extra_networks_symbol = '\U0001F3B4' # 🎴
|
extra_networks_symbol = '\U0001F3B4' # 🎴
|
||||||
switch_values_symbol = '\U000021C5' # ⇅
|
switch_values_symbol = '\U000021C5' # ⇅
|
||||||
|
restore_progress_symbol = '\U0001F300' # 🌀
|
||||||
|
|
||||||
|
|
||||||
def plaintext_to_html(text):
|
def plaintext_to_html(text):
|
||||||
@ -94,6 +95,9 @@ def send_gradio_gallery_to_image(x):
|
|||||||
|
|
||||||
def visit(x, func, path=""):
|
def visit(x, func, path=""):
|
||||||
if hasattr(x, 'children'):
|
if hasattr(x, 'children'):
|
||||||
|
if isinstance(x, gr.Tabs) and x.elem_id is not None:
|
||||||
|
# Tabs element can't have a label, have to use elem_id instead
|
||||||
|
func(f"{path}/Tabs@{x.elem_id}", x)
|
||||||
for c in x.children:
|
for c in x.children:
|
||||||
visit(c, func, path)
|
visit(c, func, path)
|
||||||
elif x.label is not None:
|
elif x.label is not None:
|
||||||
@ -127,6 +131,16 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz
|
|||||||
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
|
||||||
|
|
||||||
|
|
||||||
|
def resize_from_to_html(width, height, scale_by):
|
||||||
|
target_width = int(width * scale_by)
|
||||||
|
target_height = int(height * scale_by)
|
||||||
|
|
||||||
|
if not target_width or not target_height:
|
||||||
|
return "no image selected"
|
||||||
|
|
||||||
|
return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
|
||||||
|
|
||||||
|
|
||||||
def apply_styles(prompt, prompt_neg, styles):
|
def apply_styles(prompt, prompt_neg, styles):
|
||||||
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
|
prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
|
||||||
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
|
prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
|
||||||
@ -171,8 +185,8 @@ def create_seed_inputs(target_interface):
|
|||||||
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
|
with FormRow(elem_id=target_interface + '_seed_row', variant="compact"):
|
||||||
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
|
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
|
||||||
seed.style(container=False)
|
seed.style(container=False)
|
||||||
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed')
|
random_seed = ToolButton(random_symbol, elem_id=target_interface + '_random_seed', label='Random seed')
|
||||||
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed')
|
reuse_seed = ToolButton(reuse_symbol, elem_id=target_interface + '_reuse_seed', label='Reuse seed')
|
||||||
|
|
||||||
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
|
||||||
|
|
||||||
@ -312,6 +326,7 @@ def create_toprow(is_img2img):
|
|||||||
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
|
||||||
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
|
||||||
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
|
||||||
|
restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
|
||||||
|
|
||||||
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
|
||||||
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
|
||||||
@ -329,7 +344,7 @@ def create_toprow(is_img2img):
|
|||||||
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
|
||||||
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
|
||||||
|
|
||||||
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button
|
return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
|
||||||
|
|
||||||
|
|
||||||
def setup_progressbar(*args, **kwargs):
|
def setup_progressbar(*args, **kwargs):
|
||||||
@ -446,7 +461,7 @@ def create_ui():
|
|||||||
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
|
||||||
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False)
|
txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
|
||||||
|
|
||||||
dummy_component = gr.Label(visible=False)
|
dummy_component = gr.Label(visible=False)
|
||||||
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||||
@ -468,7 +483,7 @@ def create_ui():
|
|||||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height")
|
||||||
|
|
||||||
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
with gr.Column(elem_id="txt2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn")
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn", label="Switch dims")
|
||||||
|
|
||||||
if opts.dimensions_and_batch_together:
|
if opts.dimensions_and_batch_together:
|
||||||
with gr.Column(elem_id="txt2img_column_batch"):
|
with gr.Column(elem_id="txt2img_column_batch"):
|
||||||
@ -578,6 +593,19 @@ def create_ui():
|
|||||||
|
|
||||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||||
|
|
||||||
|
restore_progress_button.click(
|
||||||
|
fn=progress.restore_progress,
|
||||||
|
_js="restoreProgressTxt2img",
|
||||||
|
inputs=[dummy_component],
|
||||||
|
outputs=[
|
||||||
|
txt2img_gallery,
|
||||||
|
generation_info,
|
||||||
|
html_info,
|
||||||
|
html_log,
|
||||||
|
],
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
txt_prompt_img.change(
|
txt_prompt_img.change(
|
||||||
fn=modules.images.image_data,
|
fn=modules.images.image_data,
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -646,7 +674,7 @@ def create_ui():
|
|||||||
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
with gr.Blocks(analytics_enabled=False) as img2img_interface:
|
||||||
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True)
|
img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
|
||||||
|
|
||||||
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
|
||||||
|
|
||||||
@ -673,6 +701,8 @@ def create_ui():
|
|||||||
copy_image_buttons.append((button, name, elem))
|
copy_image_buttons.append((button, name, elem))
|
||||||
|
|
||||||
with gr.Tabs(elem_id="mode_img2img"):
|
with gr.Tabs(elem_id="mode_img2img"):
|
||||||
|
img2img_selected_tab = gr.State(0)
|
||||||
|
|
||||||
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
|
||||||
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
|
init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
|
||||||
add_copy_image_controls('img2img', init_img)
|
add_copy_image_controls('img2img', init_img)
|
||||||
@ -715,6 +745,12 @@ def create_ui():
|
|||||||
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
|
||||||
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir")
|
||||||
|
|
||||||
|
img2img_tabs = [tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]
|
||||||
|
img2img_image_inputs = [init_img, sketch, init_img_with_mask, inpaint_color_sketch]
|
||||||
|
|
||||||
|
for i, tab in enumerate(img2img_tabs):
|
||||||
|
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[img2img_selected_tab])
|
||||||
|
|
||||||
def copy_image(img):
|
def copy_image(img):
|
||||||
if isinstance(img, dict) and 'image' in img:
|
if isinstance(img, dict) and 'image' in img:
|
||||||
return img['image']
|
return img['image']
|
||||||
@ -744,11 +780,44 @@ def create_ui():
|
|||||||
elif category == "dimensions":
|
elif category == "dimensions":
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column(elem_id="img2img_column_size", scale=4):
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
selected_scale_tab = gr.State(value=0)
|
||||||
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
|
||||||
|
|
||||||
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
with gr.Tabs():
|
||||||
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
with gr.Tab(label="Resize to") as tab_scale_to:
|
||||||
|
with FormRow():
|
||||||
|
with gr.Column(elem_id="img2img_column_size", scale=4):
|
||||||
|
width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width")
|
||||||
|
height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height")
|
||||||
|
with gr.Column(elem_id="img2img_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
|
res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn")
|
||||||
|
|
||||||
|
with gr.Tab(label="Resize by") as tab_scale_by:
|
||||||
|
scale_by = gr.Slider(minimum=0.05, maximum=4.0, step=0.05, label="Scale", value=1.0, elem_id="img2img_scale")
|
||||||
|
|
||||||
|
with FormRow():
|
||||||
|
scale_by_html = FormHTML(resize_from_to_html(0, 0, 0.0), elem_id="img2img_scale_resolution_preview")
|
||||||
|
gr.Slider(label="Unused", elem_id="img2img_unused_scale_by_slider")
|
||||||
|
button_update_resize_to = gr.Button(visible=False, elem_id="img2img_update_resize_to")
|
||||||
|
|
||||||
|
on_change_args = dict(
|
||||||
|
fn=resize_from_to_html,
|
||||||
|
_js="currentImg2imgSourceResolution",
|
||||||
|
inputs=[dummy_component, dummy_component, scale_by],
|
||||||
|
outputs=scale_by_html,
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
scale_by.release(**on_change_args)
|
||||||
|
button_update_resize_to.click(**on_change_args)
|
||||||
|
|
||||||
|
# the code below is meant to update the resolution label after the image in the image selection UI has changed.
|
||||||
|
# as it is now the event keeps firing continuously for inpaint edits, which ruins the page with constant requests.
|
||||||
|
# I assume this must be a gradio bug and for now we'll just do it for non-inpaint inputs.
|
||||||
|
for component in [init_img, sketch]:
|
||||||
|
component.change(fn=lambda: None, _js="updateImg2imgResizeToTextAfterChangingImage", inputs=[], outputs=[], show_progress=False)
|
||||||
|
|
||||||
|
tab_scale_to.select(fn=lambda: 0, inputs=[], outputs=[selected_scale_tab])
|
||||||
|
tab_scale_by.select(fn=lambda: 1, inputs=[], outputs=[selected_scale_tab])
|
||||||
|
|
||||||
if opts.dimensions_and_batch_together:
|
if opts.dimensions_and_batch_together:
|
||||||
with gr.Column(elem_id="img2img_column_batch"):
|
with gr.Column(elem_id="img2img_column_batch"):
|
||||||
@ -759,7 +828,7 @@ def create_ui():
|
|||||||
with FormGroup():
|
with FormGroup():
|
||||||
with FormRow():
|
with FormRow():
|
||||||
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
|
||||||
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
|
||||||
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")
|
||||||
|
|
||||||
elif category == "seed":
|
elif category == "seed":
|
||||||
@ -806,7 +875,7 @@ def create_ui():
|
|||||||
def select_img2img_tab(tab):
|
def select_img2img_tab(tab):
|
||||||
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3),
|
||||||
|
|
||||||
for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]):
|
for i, elem in enumerate(img2img_tabs):
|
||||||
elem.select(
|
elem.select(
|
||||||
fn=lambda tab=i: select_img2img_tab(tab),
|
fn=lambda tab=i: select_img2img_tab(tab),
|
||||||
inputs=[],
|
inputs=[],
|
||||||
@ -859,8 +928,10 @@ def create_ui():
|
|||||||
denoising_strength,
|
denoising_strength,
|
||||||
seed,
|
seed,
|
||||||
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox,
|
||||||
|
selected_scale_tab,
|
||||||
height,
|
height,
|
||||||
width,
|
width,
|
||||||
|
scale_by,
|
||||||
resize_mode,
|
resize_mode,
|
||||||
inpaint_full_res,
|
inpaint_full_res,
|
||||||
inpaint_full_res_padding,
|
inpaint_full_res_padding,
|
||||||
@ -898,6 +969,19 @@ def create_ui():
|
|||||||
submit.click(**img2img_args)
|
submit.click(**img2img_args)
|
||||||
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height], show_progress=False)
|
||||||
|
|
||||||
|
restore_progress_button.click(
|
||||||
|
fn=progress.restore_progress,
|
||||||
|
_js="restoreProgressImg2img",
|
||||||
|
inputs=[dummy_component],
|
||||||
|
outputs=[
|
||||||
|
img2img_gallery,
|
||||||
|
generation_info,
|
||||||
|
html_info,
|
||||||
|
html_log,
|
||||||
|
],
|
||||||
|
show_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
img2img_interrogate.click(
|
img2img_interrogate.click(
|
||||||
fn=lambda *args: process_interrogate(interrogate, *args),
|
fn=lambda *args: process_interrogate(interrogate, *args),
|
||||||
**interrogate_args,
|
**interrogate_args,
|
||||||
@ -1019,8 +1103,9 @@ def create_ui():
|
|||||||
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
|
interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
|
||||||
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
|
||||||
|
save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
@ -1048,7 +1133,7 @@ def create_ui():
|
|||||||
with gr.Row(variant="compact").style(equal_height=False):
|
with gr.Row(variant="compact").style(equal_height=False):
|
||||||
with gr.Tabs(elem_id="train_tabs"):
|
with gr.Tabs(elem_id="train_tabs"):
|
||||||
|
|
||||||
with gr.Tab(label="Create embedding"):
|
with gr.Tab(label="Create embedding", id="create_embedding"):
|
||||||
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
|
new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name")
|
||||||
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
|
initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text")
|
||||||
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
|
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt")
|
||||||
@ -1061,7 +1146,7 @@ def create_ui():
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
|
create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding")
|
||||||
|
|
||||||
with gr.Tab(label="Create hypernetwork"):
|
with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
|
||||||
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
|
new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
|
||||||
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
|
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
|
||||||
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
|
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
|
||||||
@ -1079,7 +1164,7 @@ def create_ui():
|
|||||||
with gr.Column():
|
with gr.Column():
|
||||||
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
|
create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork")
|
||||||
|
|
||||||
with gr.Tab(label="Preprocess images"):
|
with gr.Tab(label="Preprocess images", id="preprocess_images"):
|
||||||
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
process_src = gr.Textbox(label='Source directory', elem_id="train_process_src")
|
||||||
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst")
|
||||||
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
|
process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width")
|
||||||
@ -1087,6 +1172,7 @@ def create_ui():
|
|||||||
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
process_keep_original_size = gr.Checkbox(label='Keep original size', elem_id="train_process_keep_original_size")
|
||||||
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
|
process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip")
|
||||||
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
|
process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split")
|
||||||
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
|
process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop")
|
||||||
@ -1146,7 +1232,7 @@ def create_ui():
|
|||||||
def get_textual_inversion_template_names():
|
def get_textual_inversion_template_names():
|
||||||
return sorted([x for x in textual_inversion.textual_inversion_templates])
|
return sorted([x for x in textual_inversion.textual_inversion_templates])
|
||||||
|
|
||||||
with gr.Tab(label="Train"):
|
with gr.Tab(label="Train", id="train"):
|
||||||
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
|
||||||
with FormRow():
|
with FormRow():
|
||||||
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
|
||||||
@ -1204,7 +1290,7 @@ def create_ui():
|
|||||||
|
|
||||||
with gr.Column(elem_id='ti_gallery_container'):
|
with gr.Column(elem_id='ti_gallery_container'):
|
||||||
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
|
||||||
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
|
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(columns=4)
|
||||||
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
ti_progress = gr.HTML(elem_id="ti_progress", value="")
|
||||||
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
ti_outcome = gr.HTML(elem_id="ti_error", value="")
|
||||||
|
|
||||||
@ -1253,6 +1339,7 @@ def create_ui():
|
|||||||
process_width,
|
process_width,
|
||||||
process_height,
|
process_height,
|
||||||
preprocess_txt_action,
|
preprocess_txt_action,
|
||||||
|
process_keep_original_size,
|
||||||
process_flip,
|
process_flip,
|
||||||
process_split,
|
process_split,
|
||||||
process_caption,
|
process_caption,
|
||||||
@ -1479,7 +1566,7 @@ def create_ui():
|
|||||||
current_row.__exit__()
|
current_row.__exit__()
|
||||||
current_tab.__exit__()
|
current_tab.__exit__()
|
||||||
|
|
||||||
with gr.TabItem("Actions"):
|
with gr.TabItem("Actions", id="actions"):
|
||||||
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
|
||||||
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
|
||||||
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
|
||||||
@ -1487,7 +1574,7 @@ def create_ui():
|
|||||||
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
|
||||||
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
|
||||||
|
|
||||||
with gr.TabItem("Licenses"):
|
with gr.TabItem("Licenses", id="licenses"):
|
||||||
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
gr.HTML(shared.html("licenses.html"), elem_id="licenses")
|
||||||
|
|
||||||
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
|
||||||
@ -1565,7 +1652,7 @@ def create_ui():
|
|||||||
for _interface, label, _ifid in interfaces:
|
for _interface, label, _ifid in interfaces:
|
||||||
shared.tab_names.append(label)
|
shared.tab_names.append(label)
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False, title="Stable Diffusion") as demo:
|
with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo:
|
||||||
with gr.Row(elem_id="quicksettings", variant="compact"):
|
with gr.Row(elem_id="quicksettings", variant="compact"):
|
||||||
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])):
|
||||||
component = create_setting_component(k, is_quicksettings=True)
|
component = create_setting_component(k, is_quicksettings=True)
|
||||||
@ -1598,18 +1685,17 @@ def create_ui():
|
|||||||
component = component_dict[k]
|
component = component_dict[k]
|
||||||
info = opts.data_labels[k]
|
info = opts.data_labels[k]
|
||||||
|
|
||||||
component.change(
|
change_handler = component.release if hasattr(component, 'release') else component.change
|
||||||
|
change_handler(
|
||||||
fn=lambda value, k=k: run_settings_single(value, key=k),
|
fn=lambda value, k=k: run_settings_single(value, key=k),
|
||||||
inputs=[component],
|
inputs=[component],
|
||||||
outputs=[component, text_settings],
|
outputs=[component, text_settings],
|
||||||
show_progress=info.refresh is not None,
|
show_progress=info.refresh is not None,
|
||||||
)
|
)
|
||||||
|
|
||||||
text_settings.change(
|
update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
|
||||||
fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
|
text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||||
inputs=[],
|
demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
|
||||||
outputs=[image_cfg_scale],
|
|
||||||
)
|
|
||||||
|
|
||||||
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
|
||||||
button_set_checkpoint.click(
|
button_set_checkpoint.click(
|
||||||
@ -1658,6 +1744,7 @@ def create_ui():
|
|||||||
config_source,
|
config_source,
|
||||||
bake_in_vae,
|
bake_in_vae,
|
||||||
discard_weights,
|
discard_weights,
|
||||||
|
save_metadata,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
primary_model_name,
|
primary_model_name,
|
||||||
@ -1705,7 +1792,7 @@ def create_ui():
|
|||||||
if init_field is not None:
|
if init_field is not None:
|
||||||
init_field(saved_value)
|
init_field(saved_value)
|
||||||
|
|
||||||
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
|
if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
|
||||||
apply_field(x, 'visible')
|
apply_field(x, 'visible')
|
||||||
|
|
||||||
if type(x) == gr.Slider:
|
if type(x) == gr.Slider:
|
||||||
@ -1735,12 +1822,27 @@ def create_ui():
|
|||||||
|
|
||||||
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))
|
||||||
|
|
||||||
|
def check_tab_id(tab_id):
|
||||||
|
tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
|
||||||
|
if type(tab_id) == str:
|
||||||
|
tab_ids = [t.id for t in tab_items]
|
||||||
|
return tab_id in tab_ids
|
||||||
|
elif type(tab_id) == int:
|
||||||
|
return tab_id >= 0 and tab_id < len(tab_items)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if type(x) == gr.Tabs:
|
||||||
|
apply_field(x, 'selected', check_tab_id)
|
||||||
|
|
||||||
visit(txt2img_interface, loadsave, "txt2img")
|
visit(txt2img_interface, loadsave, "txt2img")
|
||||||
visit(img2img_interface, loadsave, "img2img")
|
visit(img2img_interface, loadsave, "img2img")
|
||||||
visit(extras_interface, loadsave, "extras")
|
visit(extras_interface, loadsave, "extras")
|
||||||
visit(modelmerger_interface, loadsave, "modelmerger")
|
visit(modelmerger_interface, loadsave, "modelmerger")
|
||||||
visit(train_interface, loadsave, "train")
|
visit(train_interface, loadsave, "train")
|
||||||
|
|
||||||
|
loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)
|
||||||
|
|
||||||
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
|
||||||
with open(ui_config_file, "w", encoding="utf8") as file:
|
with open(ui_config_file, "w", encoding="utf8") as file:
|
||||||
json.dump(ui_settings, file, indent=4)
|
json.dump(ui_settings, file, indent=4)
|
||||||
|
@ -125,7 +125,7 @@ Requested path was: {f}
|
|||||||
|
|
||||||
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
|
||||||
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
with gr.Group(elem_id=f"{tabname}_gallery_container"):
|
||||||
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
|
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4)
|
||||||
|
|
||||||
generation_info = None
|
generation_info = None
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -62,3 +62,13 @@ class DropdownMulti(FormComponent, gr.Dropdown):
|
|||||||
|
|
||||||
def get_block_name(self):
|
def get_block_name(self):
|
||||||
return "dropdown"
|
return "dropdown"
|
||||||
|
|
||||||
|
|
||||||
|
class DropdownEditable(FormComponent, gr.Dropdown):
|
||||||
|
"""Same as gr.Dropdown but allows editing value"""
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(allow_custom_value=True, **kwargs)
|
||||||
|
|
||||||
|
def get_block_name(self):
|
||||||
|
return "dropdown"
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
|
from datetime import datetime
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import git
|
import git
|
||||||
@ -11,10 +12,12 @@ import html
|
|||||||
import shutil
|
import shutil
|
||||||
import errno
|
import errno
|
||||||
|
|
||||||
from modules import extensions, shared, paths
|
from modules import extensions, shared, paths, config_states
|
||||||
|
from modules.paths_internal import config_states_dir
|
||||||
from modules.call_queue import wrap_gradio_gpu_call
|
from modules.call_queue import wrap_gradio_gpu_call
|
||||||
|
|
||||||
available_extensions = {"extensions": []}
|
available_extensions = {"extensions": []}
|
||||||
|
STYLE_PRIMARY = ' style="color: var(--primary-400)"'
|
||||||
|
|
||||||
|
|
||||||
def check_access():
|
def check_access():
|
||||||
@ -30,6 +33,9 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
|||||||
update = json.loads(update_list)
|
update = json.loads(update_list)
|
||||||
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
|
||||||
|
|
||||||
|
if update:
|
||||||
|
save_config_state("Backup (pre-update)")
|
||||||
|
|
||||||
update = set(update)
|
update = set(update)
|
||||||
|
|
||||||
for ext in extensions.extensions:
|
for ext in extensions.extensions:
|
||||||
@ -50,6 +56,46 @@ def apply_and_restart(disable_list, update_list, disable_all):
|
|||||||
shared.state.need_restart = True
|
shared.state.need_restart = True
|
||||||
|
|
||||||
|
|
||||||
|
def save_config_state(name):
|
||||||
|
current_config_state = config_states.get_config()
|
||||||
|
if not name:
|
||||||
|
name = "Config"
|
||||||
|
current_config_state["name"] = name
|
||||||
|
filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json")
|
||||||
|
print(f"Saving backup of webui/extension state to {filename}.")
|
||||||
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(current_config_state, f)
|
||||||
|
config_states.list_config_states()
|
||||||
|
new_value = next(iter(config_states.all_config_states.keys()), "Current")
|
||||||
|
new_choices = ["Current"] + list(config_states.all_config_states.keys())
|
||||||
|
return gr.Dropdown.update(value=new_value, choices=new_choices), f"<span>Saved current webui/extension state to \"{filename}\"</span>"
|
||||||
|
|
||||||
|
|
||||||
|
def restore_config_state(confirmed, config_state_name, restore_type):
|
||||||
|
if config_state_name == "Current":
|
||||||
|
return "<span>Select a config to restore from.</span>"
|
||||||
|
if not confirmed:
|
||||||
|
return "<span>Cancelled.</span>"
|
||||||
|
|
||||||
|
check_access()
|
||||||
|
|
||||||
|
config_state = config_states.all_config_states[config_state_name]
|
||||||
|
|
||||||
|
print(f"*** Restoring webui state from backup: {restore_type} ***")
|
||||||
|
|
||||||
|
if restore_type == "extensions" or restore_type == "both":
|
||||||
|
shared.opts.restore_config_state_file = config_state["filepath"]
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
if restore_type == "webui" or restore_type == "both":
|
||||||
|
config_states.restore_webui_config(config_state)
|
||||||
|
|
||||||
|
shared.state.interrupt()
|
||||||
|
shared.state.need_restart = True
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def check_updates(id_task, disable_list):
|
def check_updates(id_task, disable_list):
|
||||||
check_access()
|
check_access()
|
||||||
|
|
||||||
@ -76,6 +122,16 @@ def check_updates(id_task, disable_list):
|
|||||||
return extension_table(), ""
|
return extension_table(), ""
|
||||||
|
|
||||||
|
|
||||||
|
def make_commit_link(commit_hash, remote, text=None):
|
||||||
|
if text is None:
|
||||||
|
text = commit_hash[:8]
|
||||||
|
if remote.startswith("https://github.com/"):
|
||||||
|
href = os.path.join(remote, "commit", commit_hash)
|
||||||
|
return f'<a href="{href}" target="_blank">{text}</a>'
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
def extension_table():
|
def extension_table():
|
||||||
code = f"""<!-- {time.time()} -->
|
code = f"""<!-- {time.time()} -->
|
||||||
<table id="extensions">
|
<table id="extensions">
|
||||||
@ -102,13 +158,17 @@ def extension_table():
|
|||||||
|
|
||||||
style = ""
|
style = ""
|
||||||
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
|
if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
|
||||||
style = ' style="color: var(--primary-400)"'
|
style = STYLE_PRIMARY
|
||||||
|
|
||||||
|
version_link = ext.version
|
||||||
|
if ext.commit_hash and ext.remote:
|
||||||
|
version_link = make_commit_link(ext.commit_hash, ext.remote, ext.version)
|
||||||
|
|
||||||
code += f"""
|
code += f"""
|
||||||
<tr>
|
<tr>
|
||||||
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
<td><label{style}><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
|
||||||
<td>{remote}</td>
|
<td>{remote}</td>
|
||||||
<td>{ext.version}</td>
|
<td>{version_link}</td>
|
||||||
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
|
||||||
</tr>
|
</tr>
|
||||||
"""
|
"""
|
||||||
@ -121,6 +181,133 @@ def extension_table():
|
|||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
||||||
|
def update_config_states_table(state_name):
|
||||||
|
if state_name == "Current":
|
||||||
|
config_state = config_states.get_config()
|
||||||
|
else:
|
||||||
|
config_state = config_states.all_config_states[state_name]
|
||||||
|
|
||||||
|
config_name = config_state.get("name", "Config")
|
||||||
|
created_date = time.asctime(time.gmtime(config_state["created_at"]))
|
||||||
|
filepath = config_state.get("filepath", "<unknown>")
|
||||||
|
|
||||||
|
code = f"""<!-- {time.time()} -->"""
|
||||||
|
|
||||||
|
webui_remote = config_state["webui"]["remote"] or ""
|
||||||
|
webui_branch = config_state["webui"]["branch"]
|
||||||
|
webui_commit_hash = config_state["webui"]["commit_hash"] or "<unknown>"
|
||||||
|
webui_commit_date = config_state["webui"]["commit_date"]
|
||||||
|
if webui_commit_date:
|
||||||
|
webui_commit_date = time.asctime(time.gmtime(webui_commit_date))
|
||||||
|
else:
|
||||||
|
webui_commit_date = "<unknown>"
|
||||||
|
|
||||||
|
remote = f"""<a href="{html.escape(webui_remote)}" target="_blank">{html.escape(webui_remote or '')}</a>"""
|
||||||
|
commit_link = make_commit_link(webui_commit_hash, webui_remote)
|
||||||
|
date_link = make_commit_link(webui_commit_hash, webui_remote, webui_commit_date)
|
||||||
|
|
||||||
|
current_webui = config_states.get_webui_config()
|
||||||
|
|
||||||
|
style_remote = ""
|
||||||
|
style_branch = ""
|
||||||
|
style_commit = ""
|
||||||
|
if current_webui["remote"] != webui_remote:
|
||||||
|
style_remote = STYLE_PRIMARY
|
||||||
|
if current_webui["branch"] != webui_branch:
|
||||||
|
style_branch = STYLE_PRIMARY
|
||||||
|
if current_webui["commit_hash"] != webui_commit_hash:
|
||||||
|
style_commit = STYLE_PRIMARY
|
||||||
|
|
||||||
|
code += f"""<h2>Config Backup: {config_name}</h2>
|
||||||
|
<div><b>Filepath:</b> {filepath}</div>
|
||||||
|
<div><b>Created at:</b> {created_date}</div>"""
|
||||||
|
|
||||||
|
code += f"""<h2>WebUI State</h2>
|
||||||
|
<table id="config_state_webui">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>URL</th>
|
||||||
|
<th>Branch</th>
|
||||||
|
<th>Commit</th>
|
||||||
|
<th>Date</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
<tr>
|
||||||
|
<td><label{style_remote}>{remote}</label></td>
|
||||||
|
<td><label{style_branch}>{webui_branch}</label></td>
|
||||||
|
<td><label{style_commit}>{commit_link}</label></td>
|
||||||
|
<td><label{style_commit}>{date_link}</label></td>
|
||||||
|
</tr>
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
"""
|
||||||
|
|
||||||
|
code += """<h2>Extension State</h2>
|
||||||
|
<table id="config_state_extensions">
|
||||||
|
<thead>
|
||||||
|
<tr>
|
||||||
|
<th>Extension</th>
|
||||||
|
<th>URL</th>
|
||||||
|
<th>Branch</th>
|
||||||
|
<th>Commit</th>
|
||||||
|
<th>Date</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
"""
|
||||||
|
|
||||||
|
ext_map = {ext.name: ext for ext in extensions.extensions}
|
||||||
|
|
||||||
|
for ext_name, ext_conf in config_state["extensions"].items():
|
||||||
|
ext_remote = ext_conf["remote"] or ""
|
||||||
|
ext_branch = ext_conf["branch"] or "<unknown>"
|
||||||
|
ext_enabled = ext_conf["enabled"]
|
||||||
|
ext_commit_hash = ext_conf["commit_hash"] or "<unknown>"
|
||||||
|
ext_commit_date = ext_conf["commit_date"]
|
||||||
|
if ext_commit_date:
|
||||||
|
ext_commit_date = time.asctime(time.gmtime(ext_commit_date))
|
||||||
|
else:
|
||||||
|
ext_commit_date = "<unknown>"
|
||||||
|
|
||||||
|
remote = f"""<a href="{html.escape(ext_remote)}" target="_blank">{html.escape(ext_remote or '')}</a>"""
|
||||||
|
commit_link = make_commit_link(ext_commit_hash, ext_remote)
|
||||||
|
date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date)
|
||||||
|
|
||||||
|
style_enabled = ""
|
||||||
|
style_remote = ""
|
||||||
|
style_branch = ""
|
||||||
|
style_commit = ""
|
||||||
|
if ext_name in ext_map:
|
||||||
|
current_ext = ext_map[ext_name]
|
||||||
|
current_ext.read_info_from_repo()
|
||||||
|
if current_ext.enabled != ext_enabled:
|
||||||
|
style_enabled = STYLE_PRIMARY
|
||||||
|
if current_ext.remote != ext_remote:
|
||||||
|
style_remote = STYLE_PRIMARY
|
||||||
|
if current_ext.branch != ext_branch:
|
||||||
|
style_branch = STYLE_PRIMARY
|
||||||
|
if current_ext.commit_hash != ext_commit_hash:
|
||||||
|
style_commit = STYLE_PRIMARY
|
||||||
|
|
||||||
|
code += f"""
|
||||||
|
<tr>
|
||||||
|
<td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td>
|
||||||
|
<td><label{style_remote}>{remote}</label></td>
|
||||||
|
<td><label{style_branch}>{ext_branch}</label></td>
|
||||||
|
<td><label{style_commit}>{commit_link}</label></td>
|
||||||
|
<td><label{style_commit}>{date_link}</label></td>
|
||||||
|
</tr>
|
||||||
|
"""
|
||||||
|
|
||||||
|
code += """
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
"""
|
||||||
|
|
||||||
|
return code
|
||||||
|
|
||||||
|
|
||||||
def normalize_git_url(url):
|
def normalize_git_url(url):
|
||||||
if url is None:
|
if url is None:
|
||||||
return ""
|
return ""
|
||||||
@ -129,7 +316,7 @@ def normalize_git_url(url):
|
|||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
def install_extension_from_url(dirname, url):
|
def install_extension_from_url(dirname, url, branch_name=None):
|
||||||
check_access()
|
check_access()
|
||||||
|
|
||||||
assert url, 'No URL specified'
|
assert url, 'No URL specified'
|
||||||
@ -150,10 +337,17 @@ def install_extension_from_url(dirname, url):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
shutil.rmtree(tmpdir, True)
|
shutil.rmtree(tmpdir, True)
|
||||||
with git.Repo.clone_from(url, tmpdir) as repo:
|
if not branch_name:
|
||||||
repo.remote().fetch()
|
# if no branch is specified, use the default branch
|
||||||
for submodule in repo.submodules:
|
with git.Repo.clone_from(url, tmpdir) as repo:
|
||||||
submodule.update()
|
repo.remote().fetch()
|
||||||
|
for submodule in repo.submodules:
|
||||||
|
submodule.update()
|
||||||
|
else:
|
||||||
|
with git.Repo.clone_from(url, tmpdir, branch=branch_name) as repo:
|
||||||
|
repo.remote().fetch()
|
||||||
|
for submodule in repo.submodules:
|
||||||
|
submodule.update()
|
||||||
try:
|
try:
|
||||||
os.rename(tmpdir, target_dir)
|
os.rename(tmpdir, target_dir)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
@ -292,9 +486,11 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text="
|
|||||||
def create_ui():
|
def create_ui():
|
||||||
import modules.ui
|
import modules.ui
|
||||||
|
|
||||||
|
config_states.list_config_states()
|
||||||
|
|
||||||
with gr.Blocks(analytics_enabled=False) as ui:
|
with gr.Blocks(analytics_enabled=False) as ui:
|
||||||
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
with gr.Tabs(elem_id="tabs_extensions") as tabs:
|
||||||
with gr.TabItem("Installed"):
|
with gr.TabItem("Installed", id="installed"):
|
||||||
|
|
||||||
with gr.Row(elem_id="extensions_installed_top"):
|
with gr.Row(elem_id="extensions_installed_top"):
|
||||||
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
apply = gr.Button(value="Apply and restart UI", variant="primary")
|
||||||
@ -327,7 +523,7 @@ def create_ui():
|
|||||||
outputs=[extensions_table, info],
|
outputs=[extensions_table, info],
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.TabItem("Available"):
|
with gr.TabItem("Available", id="available"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
|
||||||
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
|
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json", label="Extension index URL").style(container=False)
|
||||||
@ -374,16 +570,41 @@ def create_ui():
|
|||||||
outputs=[available_extensions_table, install_result]
|
outputs=[available_extensions_table, install_result]
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.TabItem("Install from URL"):
|
with gr.TabItem("Install from URL", id="install_from_url"):
|
||||||
install_url = gr.Text(label="URL for extension's git repository")
|
install_url = gr.Text(label="URL for extension's git repository")
|
||||||
|
install_branch = gr.Text(label="Specific branch name", placeholder="Leave empty for default main branch")
|
||||||
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
|
||||||
install_button = gr.Button(value="Install", variant="primary")
|
install_button = gr.Button(value="Install", variant="primary")
|
||||||
install_result = gr.HTML(elem_id="extension_install_result")
|
install_result = gr.HTML(elem_id="extension_install_result")
|
||||||
|
|
||||||
install_button.click(
|
install_button.click(
|
||||||
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
|
||||||
inputs=[install_dirname, install_url],
|
inputs=[install_dirname, install_url, install_branch],
|
||||||
outputs=[extensions_table, install_result],
|
outputs=[extensions_table, install_result],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.TabItem("Backup/Restore"):
|
||||||
|
with gr.Row(elem_id="extensions_backup_top_row"):
|
||||||
|
config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys()))
|
||||||
|
modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states")
|
||||||
|
config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type")
|
||||||
|
config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore")
|
||||||
|
with gr.Row(elem_id="extensions_backup_top_row2"):
|
||||||
|
config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False)
|
||||||
|
config_save_button = gr.Button(value="Save Current Config")
|
||||||
|
|
||||||
|
config_states_info = gr.HTML("")
|
||||||
|
config_states_table = gr.HTML(lambda: update_config_states_table("Current"))
|
||||||
|
|
||||||
|
config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info])
|
||||||
|
|
||||||
|
dummy_component = gr.Label(visible=False)
|
||||||
|
config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info])
|
||||||
|
|
||||||
|
config_states_list.change(
|
||||||
|
fn=update_config_states_table,
|
||||||
|
inputs=[config_states_list],
|
||||||
|
outputs=[config_states_table],
|
||||||
|
)
|
||||||
|
|
||||||
return ui
|
return ui
|
||||||
|
@ -241,7 +241,7 @@ def create_ui(container, button, tabname):
|
|||||||
|
|
||||||
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
|
||||||
for page in ui.stored_extra_pages:
|
for page in ui.stored_extra_pages:
|
||||||
with gr.Tab(page.title):
|
with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")):
|
||||||
|
|
||||||
page_elem = gr.HTML(page.create_html(ui.tabname))
|
page_elem = gr.HTML(page.create_html(ui.tabname))
|
||||||
ui.pages.append(page_elem)
|
ui.pages.append(page_elem)
|
||||||
|
@ -9,13 +9,13 @@ def create_ui():
|
|||||||
with gr.Row().style(equal_height=False, variant='compact'):
|
with gr.Row().style(equal_height=False, variant='compact'):
|
||||||
with gr.Column(variant='compact'):
|
with gr.Column(variant='compact'):
|
||||||
with gr.Tabs(elem_id="mode_extras"):
|
with gr.Tabs(elem_id="mode_extras"):
|
||||||
with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
|
with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
|
||||||
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
|
||||||
|
|
||||||
with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
|
with gr.TabItem('Batch Process', id="batch_process", elem_id="extras_batch_process_tab") as tab_batch:
|
||||||
image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
|
image_batch = gr.Files(label="Batch Process", interactive=True, elem_id="extras_image_batch")
|
||||||
|
|
||||||
with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
with gr.TabItem('Batch from Directory', id="batch_from_directory", elem_id="extras_batch_directory_tab") as tab_batch_dir:
|
||||||
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
|
||||||
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
|
||||||
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
|
||||||
|
@ -36,7 +36,7 @@ def save_pil_to_file(pil_image, dir=None):
|
|||||||
if already_saved_as and os.path.isfile(already_saved_as):
|
if already_saved_as and os.path.isfile(already_saved_as):
|
||||||
register_tmp_file(shared.demo, already_saved_as)
|
register_tmp_file(shared.demo, already_saved_as)
|
||||||
|
|
||||||
file_obj = Savedfile(already_saved_as)
|
file_obj = Savedfile(f"{already_saved_as}?{os.path.getmtime(already_saved_as)}")
|
||||||
return file_obj
|
return file_obj
|
||||||
|
|
||||||
if shared.opts.temp_dir != "":
|
if shared.opts.temp_dir != "":
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
|
astunparse
|
||||||
blendmodes
|
blendmodes
|
||||||
accelerate
|
accelerate
|
||||||
basicsr
|
basicsr
|
||||||
fonts
|
fonts
|
||||||
font-roboto
|
font-roboto
|
||||||
gfpgan
|
gfpgan
|
||||||
gradio==3.23
|
gradio==3.28.1
|
||||||
invisible-watermark
|
|
||||||
numpy
|
numpy
|
||||||
omegaconf
|
omegaconf
|
||||||
opencv-contrib-python
|
opencv-contrib-python
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
blendmodes==2022
|
blendmodes==2022
|
||||||
transformers==4.25.1
|
transformers==4.25.1
|
||||||
accelerate==0.12.0
|
accelerate==0.18.0
|
||||||
basicsr==1.4.2
|
basicsr==1.4.2
|
||||||
gfpgan==1.3.8
|
gfpgan==1.3.8
|
||||||
gradio==3.23
|
gradio==3.28.1
|
||||||
numpy==1.23.3
|
numpy==1.23.5
|
||||||
Pillow==9.4.0
|
Pillow==9.4.0
|
||||||
realesrgan==0.3.0
|
realesrgan==0.3.0
|
||||||
torch
|
torch
|
||||||
@ -25,6 +25,6 @@ lark==1.1.2
|
|||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
GitPython==3.1.30
|
GitPython==3.1.30
|
||||||
torchsde==0.2.5
|
torchsde==0.2.5
|
||||||
safetensors==0.3.0
|
safetensors==0.3.1
|
||||||
httpcore<=0.15
|
httpcore<=0.15
|
||||||
fastapi==0.94.0
|
fastapi==0.94.0
|
||||||
|
@ -7,7 +7,7 @@ function gradioApp() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function get_uiCurrentTab() {
|
function get_uiCurrentTab() {
|
||||||
return gradioApp().querySelector('#tabs button:not(.border-transparent)')
|
return gradioApp().querySelector('#tabs button.selected')
|
||||||
}
|
}
|
||||||
|
|
||||||
function get_uiCurrentTabContent() {
|
function get_uiCurrentTabContent() {
|
||||||
|
@ -1,9 +1,40 @@
|
|||||||
import modules.scripts as scripts
|
import modules.scripts as scripts
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import ast
|
||||||
|
import copy
|
||||||
|
|
||||||
from modules.processing import Processed
|
from modules.processing import Processed
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
|
|
||||||
|
|
||||||
|
def convertExpr2Expression(expr):
|
||||||
|
expr.lineno = 0
|
||||||
|
expr.col_offset = 0
|
||||||
|
result = ast.Expression(expr.value, lineno=0, col_offset = 0)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def exec_with_return(code, module):
|
||||||
|
"""
|
||||||
|
like exec() but can return values
|
||||||
|
https://stackoverflow.com/a/52361938/5862977
|
||||||
|
"""
|
||||||
|
code_ast = ast.parse(code)
|
||||||
|
|
||||||
|
init_ast = copy.deepcopy(code_ast)
|
||||||
|
init_ast.body = code_ast.body[:-1]
|
||||||
|
|
||||||
|
last_ast = copy.deepcopy(code_ast)
|
||||||
|
last_ast.body = code_ast.body[-1:]
|
||||||
|
|
||||||
|
exec(compile(init_ast, "<ast>", "exec"), module.__dict__)
|
||||||
|
if type(last_ast.body[0]) == ast.Expr:
|
||||||
|
return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__)
|
||||||
|
else:
|
||||||
|
exec(compile(last_ast, "<ast>", "exec"), module.__dict__)
|
||||||
|
|
||||||
|
|
||||||
class Script(scripts.Script):
|
class Script(scripts.Script):
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
@ -13,12 +44,23 @@ class Script(scripts.Script):
|
|||||||
return cmd_opts.allow_code
|
return cmd_opts.allow_code
|
||||||
|
|
||||||
def ui(self, is_img2img):
|
def ui(self, is_img2img):
|
||||||
code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code"))
|
example = """from modules.processing import process_images
|
||||||
|
|
||||||
return [code]
|
p.width = 768
|
||||||
|
p.height = 768
|
||||||
|
p.batch_size = 2
|
||||||
|
p.steps = 10
|
||||||
|
|
||||||
|
return process_images(p)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def run(self, p, code):
|
code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code"))
|
||||||
|
indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level"))
|
||||||
|
|
||||||
|
return [code, indent_level]
|
||||||
|
|
||||||
|
def run(self, p, code, indent_level):
|
||||||
assert cmd_opts.allow_code, '--allow-code option must be enabled'
|
assert cmd_opts.allow_code, '--allow-code option must be enabled'
|
||||||
|
|
||||||
display_result_data = [[], -1, ""]
|
display_result_data = [[], -1, ""]
|
||||||
@ -29,13 +71,20 @@ class Script(scripts.Script):
|
|||||||
display_result_data[2] = i
|
display_result_data[2] = i
|
||||||
|
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
compiled = compile(code, '', 'exec')
|
|
||||||
module = ModuleType("testmodule")
|
module = ModuleType("testmodule")
|
||||||
module.__dict__.update(globals())
|
module.__dict__.update(globals())
|
||||||
module.p = p
|
module.p = p
|
||||||
module.display = display
|
module.display = display
|
||||||
exec(compiled, module.__dict__)
|
|
||||||
|
indent = " " * indent_level
|
||||||
|
indented = code.replace('\n', '\n' + indent)
|
||||||
|
body = f"""def __webuitemp__():
|
||||||
|
{indent}{indented}
|
||||||
|
__webuitemp__()"""
|
||||||
|
|
||||||
|
result = exec_with_return(body, module)
|
||||||
|
|
||||||
|
if isinstance(result, Processed):
|
||||||
|
return result
|
||||||
|
|
||||||
return Processed(p, *display_result_data)
|
return Processed(p, *display_result_data)
|
||||||
|
|
||||||
|
|
@ -275,7 +275,7 @@ class Script(scripts.Script):
|
|||||||
|
|
||||||
if opts.samples_save:
|
if opts.samples_save:
|
||||||
for img in all_processed_images:
|
for img in all_processed_images:
|
||||||
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
|
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.samples_format, info=res.info, p=p)
|
||||||
|
|
||||||
if opts.grid_save and not unwanted_grid_because_of_img_count:
|
if opts.grid_save and not unwanted_grid_because_of_img_count:
|
||||||
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
|
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
|
||||||
|
@ -138,7 +138,7 @@ class Script(scripts.Script):
|
|||||||
combined_image = images.combine_grid(grid)
|
combined_image = images.combine_grid(grid)
|
||||||
|
|
||||||
if opts.samples_save:
|
if opts.samples_save:
|
||||||
images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info, p=p)
|
images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.samples_format, info=initial_info, p=p)
|
||||||
|
|
||||||
processed = Processed(p, [combined_image], initial_seed, initial_info)
|
processed = Processed(p, [combined_image], initial_seed, initial_info)
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import numpy as np
|
|||||||
from modules import scripts_postprocessing, shared
|
from modules import scripts_postprocessing, shared
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
from modules.ui_components import FormRow
|
from modules.ui_components import FormRow, ToolButton
|
||||||
|
from modules.ui import switch_values_symbol
|
||||||
|
|
||||||
upscale_cache = {}
|
upscale_cache = {}
|
||||||
|
|
||||||
@ -25,9 +25,12 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
|
|
||||||
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
|
with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
|
||||||
with FormRow():
|
with FormRow():
|
||||||
upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
|
with gr.Column(elem_id="upscaling_column_size", scale=4):
|
||||||
upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
|
upscaling_resize_w = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="extras_upscaling_resize_w")
|
||||||
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
upscaling_resize_h = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="extras_upscaling_resize_h")
|
||||||
|
with gr.Column(elem_id="upscaling_dimensions_row", scale=1, elem_classes="dimensions-tools"):
|
||||||
|
upscaling_res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="upscaling_res_switch_btn")
|
||||||
|
upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
|
||||||
|
|
||||||
with FormRow():
|
with FormRow():
|
||||||
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
@ -36,6 +39,7 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
|
|||||||
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
|
||||||
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
|
||||||
|
|
||||||
|
upscaling_res_switch_btn.click(lambda w, h: (h, w), inputs=[upscaling_resize_w, upscaling_resize_h], outputs=[upscaling_resize_w, upscaling_resize_h], show_progress=False)
|
||||||
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
|
tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
|
||||||
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
|
tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ def apply_checkpoint(p, x, xs):
|
|||||||
info = modules.sd_models.get_closet_checkpoint_match(x)
|
info = modules.sd_models.get_closet_checkpoint_match(x)
|
||||||
if info is None:
|
if info is None:
|
||||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
p.override_settings['sd_model_checkpoint'] = info.hash
|
||||||
|
|
||||||
|
|
||||||
def confirm_checkpoints(p, xs):
|
def confirm_checkpoints(p, xs):
|
||||||
@ -211,7 +211,8 @@ axis_options = [
|
|||||||
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
|
||||||
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
|
||||||
AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
|
AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]),
|
||||||
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
|
AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
|
||||||
|
AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
|
||||||
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
AxisOption("Sigma Churn", float, apply_field("s_churn")),
|
||||||
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
AxisOption("Sigma min", float, apply_field("s_tmin")),
|
||||||
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
AxisOption("Sigma max", float, apply_field("s_tmax")),
|
||||||
@ -374,16 +375,19 @@ class Script(scripts.Script):
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
|
||||||
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
|
||||||
|
x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
|
||||||
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
|
fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
|
||||||
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
|
||||||
|
y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
|
||||||
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
|
fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
|
z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
|
||||||
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
|
z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
|
||||||
|
z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
|
||||||
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)
|
||||||
|
|
||||||
with gr.Row(variant="compact", elem_id="axis_options"):
|
with gr.Row(variant="compact", elem_id="axis_options"):
|
||||||
@ -401,54 +405,74 @@ class Script(scripts.Script):
|
|||||||
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
|
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
|
||||||
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
|
swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button")
|
||||||
|
|
||||||
def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values):
|
def swap_axes(axis1_type, axis1_values, axis1_values_dropdown, axis2_type, axis2_values, axis2_values_dropdown):
|
||||||
return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values
|
return self.current_axis_options[axis2_type].label, axis2_values, axis2_values_dropdown, self.current_axis_options[axis1_type].label, axis1_values, axis1_values_dropdown
|
||||||
|
|
||||||
xy_swap_args = [x_type, x_values, y_type, y_values]
|
xy_swap_args = [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown]
|
||||||
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
|
swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args)
|
||||||
yz_swap_args = [y_type, y_values, z_type, z_values]
|
yz_swap_args = [y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown]
|
||||||
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
|
swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args)
|
||||||
xz_swap_args = [x_type, x_values, z_type, z_values]
|
xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown]
|
||||||
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
|
swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args)
|
||||||
|
|
||||||
def fill(x_type):
|
def fill(x_type):
|
||||||
axis = self.current_axis_options[x_type]
|
axis = self.current_axis_options[x_type]
|
||||||
return ", ".join(axis.choices()) if axis.choices else gr.update()
|
return axis.choices() if axis.choices else gr.update()
|
||||||
|
|
||||||
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
|
fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
|
||||||
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
|
fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
|
||||||
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
|
fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])
|
||||||
|
|
||||||
def select_axis(x_type):
|
def select_axis(axis_type,axis_values_dropdown):
|
||||||
return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
|
choices = self.current_axis_options[axis_type].choices
|
||||||
|
has_choices = choices is not None
|
||||||
|
current_values = axis_values_dropdown
|
||||||
|
if has_choices:
|
||||||
|
choices = choices()
|
||||||
|
if isinstance(current_values,str):
|
||||||
|
current_values = current_values.split(",")
|
||||||
|
current_values = list(filter(lambda x: x in choices, current_values))
|
||||||
|
return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices if has_choices else None,visible=has_choices,value=current_values)
|
||||||
|
|
||||||
x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
|
x_type.change(fn=select_axis, inputs=[x_type,x_values_dropdown], outputs=[fill_x_button,x_values,x_values_dropdown])
|
||||||
y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
|
y_type.change(fn=select_axis, inputs=[y_type,y_values_dropdown], outputs=[fill_y_button,y_values,y_values_dropdown])
|
||||||
z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
|
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
|
||||||
|
|
||||||
|
def get_dropdown_update_from_params(axis,params):
|
||||||
|
val_key = axis + " Values"
|
||||||
|
vals = params.get(val_key,"")
|
||||||
|
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
||||||
|
return gr.update(value = valslist)
|
||||||
|
|
||||||
self.infotext_fields = (
|
self.infotext_fields = (
|
||||||
(x_type, "X Type"),
|
(x_type, "X Type"),
|
||||||
(x_values, "X Values"),
|
(x_values, "X Values"),
|
||||||
|
(x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)),
|
||||||
(y_type, "Y Type"),
|
(y_type, "Y Type"),
|
||||||
(y_values, "Y Values"),
|
(y_values, "Y Values"),
|
||||||
|
(y_values_dropdown, lambda params:get_dropdown_update_from_params("Y",params)),
|
||||||
(z_type, "Z Type"),
|
(z_type, "Z Type"),
|
||||||
(z_values, "Z Values"),
|
(z_values, "Z Values"),
|
||||||
|
(z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)),
|
||||||
)
|
)
|
||||||
|
|
||||||
return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
|
return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
|
||||||
|
|
||||||
def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
|
def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
|
||||||
if not no_fixed_seeds:
|
if not no_fixed_seeds:
|
||||||
modules.processing.fix_seed(p)
|
modules.processing.fix_seed(p)
|
||||||
|
|
||||||
if not opts.return_grid:
|
if not opts.return_grid:
|
||||||
p.batch_size = 1
|
p.batch_size = 1
|
||||||
|
|
||||||
def process_axis(opt, vals):
|
def process_axis(opt, vals, vals_dropdown):
|
||||||
if opt.label == 'Nothing':
|
if opt.label == 'Nothing':
|
||||||
return [0]
|
return [0]
|
||||||
|
|
||||||
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
if opt.choices is not None:
|
||||||
|
valslist = vals_dropdown
|
||||||
|
else:
|
||||||
|
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
|
||||||
|
|
||||||
if opt.type == int:
|
if opt.type == int:
|
||||||
valslist_ext = []
|
valslist_ext = []
|
||||||
@ -506,13 +530,19 @@ class Script(scripts.Script):
|
|||||||
return valslist
|
return valslist
|
||||||
|
|
||||||
x_opt = self.current_axis_options[x_type]
|
x_opt = self.current_axis_options[x_type]
|
||||||
xs = process_axis(x_opt, x_values)
|
if x_opt.choices is not None:
|
||||||
|
x_values = ",".join(x_values_dropdown)
|
||||||
|
xs = process_axis(x_opt, x_values, x_values_dropdown)
|
||||||
|
|
||||||
y_opt = self.current_axis_options[y_type]
|
y_opt = self.current_axis_options[y_type]
|
||||||
ys = process_axis(y_opt, y_values)
|
if y_opt.choices is not None:
|
||||||
|
y_values = ",".join(y_values_dropdown)
|
||||||
|
ys = process_axis(y_opt, y_values, y_values_dropdown)
|
||||||
|
|
||||||
z_opt = self.current_axis_options[z_type]
|
z_opt = self.current_axis_options[z_type]
|
||||||
zs = process_axis(z_opt, z_values)
|
if z_opt.choices is not None:
|
||||||
|
z_values = ",".join(z_values_dropdown)
|
||||||
|
zs = process_axis(z_opt, z_values, z_values_dropdown)
|
||||||
|
|
||||||
# this could be moved to common code, but unlikely to be ever triggered anywhere else
|
# this could be moved to common code, but unlikely to be ever triggered anywhere else
|
||||||
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
|
Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes
|
||||||
|
33
style.css
33
style.css
@ -246,7 +246,7 @@ button.custom-button{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#txt2img_gallery img, #img2img_gallery img{
|
#txt2img_gallery img, #img2img_gallery img, #extras_gallery img{
|
||||||
object-fit: scale-down;
|
object-fit: scale-down;
|
||||||
}
|
}
|
||||||
#txt2img_actions_column, #img2img_actions_column {
|
#txt2img_actions_column, #img2img_actions_column {
|
||||||
@ -293,7 +293,12 @@ button.custom-button{
|
|||||||
margin-left: -0.75em
|
margin-left: -0.75em
|
||||||
}
|
}
|
||||||
|
|
||||||
#txtimg_hr_finalres .resolution{
|
#img2img_scale_resolution_preview.block{
|
||||||
|
display: flex;
|
||||||
|
align-items: end;
|
||||||
|
}
|
||||||
|
|
||||||
|
#txtimg_hr_finalres .resolution, #img2img_scale_resolution_preview .resolution{
|
||||||
font-weight: bold;
|
font-weight: bold;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -312,6 +317,10 @@ div.dimensions-tools{
|
|||||||
align-content: center;
|
align-content: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
div#extras_scale_to_tab div.form{
|
||||||
|
flex-direction: row;
|
||||||
|
}
|
||||||
|
|
||||||
#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
|
#mode_img2img .gradio-image > div.fixed-height, #mode_img2img .gradio-image > div.fixed-height img{
|
||||||
height: 480px !important;
|
height: 480px !important;
|
||||||
max-height: 480px !important;
|
max-height: 480px !important;
|
||||||
@ -333,6 +342,18 @@ div.dimensions-tools{
|
|||||||
overflow-wrap: break-word;
|
overflow-wrap: break-word;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#img2img_column_batch{
|
||||||
|
align-self: end;
|
||||||
|
margin-bottom: 0.9em;
|
||||||
|
}
|
||||||
|
|
||||||
|
#img2img_unused_scale_by_slider{
|
||||||
|
visibility: hidden;
|
||||||
|
width: 0.5em;
|
||||||
|
max-width: 0.5em;
|
||||||
|
min-width: 0.5em;
|
||||||
|
}
|
||||||
|
|
||||||
/* settings */
|
/* settings */
|
||||||
#quicksettings {
|
#quicksettings {
|
||||||
width: fit-content;
|
width: fit-content;
|
||||||
@ -513,6 +534,8 @@ div.dimensions-tools{
|
|||||||
#lightboxModal > img.modalImageFullscreen{
|
#lightboxModal > img.modalImageFullscreen{
|
||||||
object-fit: contain;
|
object-fit: contain;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
|
width: 100%;
|
||||||
|
min-height: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.modalPrev,
|
.modalPrev,
|
||||||
@ -642,6 +665,12 @@ footer {
|
|||||||
|
|
||||||
/* extra networks UI */
|
/* extra networks UI */
|
||||||
|
|
||||||
|
.extra-network-cards{
|
||||||
|
height: 725px;
|
||||||
|
overflow: scroll;
|
||||||
|
resize: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
.extra-networks > div > [id *= '_extra_']{
|
.extra-networks > div > [id *= '_extra_']{
|
||||||
margin: 0.3em;
|
margin: 0.3em;
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ fi
|
|||||||
|
|
||||||
export install_dir="$HOME"
|
export install_dir="$HOME"
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
|
export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate"
|
||||||
export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1"
|
export TORCH_COMMAND="pip install torch torchvision"
|
||||||
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git"
|
||||||
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71"
|
||||||
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
export PYTORCH_ENABLE_MPS_FALLBACK=1
|
||||||
|
@ -43,4 +43,7 @@
|
|||||||
# Uncomment to enable accelerated launch
|
# Uncomment to enable accelerated launch
|
||||||
#export ACCELERATE="True"
|
#export ACCELERATE="True"
|
||||||
|
|
||||||
|
# Uncomment to disable TCMalloc
|
||||||
|
#export NO_TCMALLOC="True"
|
||||||
|
|
||||||
###########################################
|
###########################################
|
||||||
|
117
webui.py
117
webui.py
@ -5,6 +5,9 @@ import importlib
|
|||||||
import signal
|
import signal
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
import json
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.middleware.gzip import GZipMiddleware
|
from fastapi.middleware.gzip import GZipMiddleware
|
||||||
@ -20,6 +23,9 @@ startup_timer = timer.Timer()
|
|||||||
import torch
|
import torch
|
||||||
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
|
import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them
|
||||||
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
|
||||||
|
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
|
||||||
|
|
||||||
|
|
||||||
startup_timer.record("import torch")
|
startup_timer.record("import torch")
|
||||||
|
|
||||||
import gradio
|
import gradio
|
||||||
@ -37,7 +43,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
|
|||||||
torch.__long_version__ = torch.__version__
|
torch.__long_version__ = torch.__version__
|
||||||
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
|
||||||
|
|
||||||
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
|
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
|
||||||
import modules.codeformer_model as codeformer
|
import modules.codeformer_model as codeformer
|
||||||
import modules.face_restoration
|
import modules.face_restoration
|
||||||
import modules.gfpgan_model as gfpgan
|
import modules.gfpgan_model as gfpgan
|
||||||
@ -67,11 +73,51 @@ else:
|
|||||||
server_name = "0.0.0.0" if cmd_opts.listen else None
|
server_name = "0.0.0.0" if cmd_opts.listen else None
|
||||||
|
|
||||||
|
|
||||||
|
def fix_asyncio_event_loop_policy():
|
||||||
|
"""
|
||||||
|
The default `asyncio` event loop policy only automatically creates
|
||||||
|
event loops in the main threads. Other threads must create event
|
||||||
|
loops explicitly or `asyncio.get_event_loop` (and therefore
|
||||||
|
`.IOLoop.current`) will fail. Installing this policy allows event
|
||||||
|
loops to be created automatically on any thread, matching the
|
||||||
|
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
|
||||||
|
# "Any thread" and "selector" should be orthogonal, but there's not a clean
|
||||||
|
# interface for composing policies so pick the right base.
|
||||||
|
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
|
||||||
|
else:
|
||||||
|
_BasePolicy = asyncio.DefaultEventLoopPolicy
|
||||||
|
|
||||||
|
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
|
||||||
|
"""Event loop policy that allows loop creation on any thread.
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_event_loop(self) -> asyncio.AbstractEventLoop:
|
||||||
|
try:
|
||||||
|
return super().get_event_loop()
|
||||||
|
except (RuntimeError, AssertionError):
|
||||||
|
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
|
||||||
|
# and changed to a RuntimeError in 3.4.3.
|
||||||
|
# "There is no current event loop in thread %r"
|
||||||
|
loop = self.new_event_loop()
|
||||||
|
self.set_event_loop(loop)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
|
||||||
|
|
||||||
|
|
||||||
def check_versions():
|
def check_versions():
|
||||||
if shared.cmd_opts.skip_version_check:
|
if shared.cmd_opts.skip_version_check:
|
||||||
return
|
return
|
||||||
|
|
||||||
expected_torch_version = "1.13.1"
|
expected_torch_version = "2.0.0"
|
||||||
|
|
||||||
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
if version.parse(torch.__version__) < version.parse(expected_torch_version):
|
||||||
errors.print_error_explanation(f"""
|
errors.print_error_explanation(f"""
|
||||||
@ -84,7 +130,7 @@ there are reports of issues with training tab on the latest version.
|
|||||||
Use --skip-version-check commandline argument to disable this check.
|
Use --skip-version-check commandline argument to disable this check.
|
||||||
""".strip())
|
""".strip())
|
||||||
|
|
||||||
expected_xformers_version = "0.0.16rc425"
|
expected_xformers_version = "0.0.17"
|
||||||
if shared.xformers_available:
|
if shared.xformers_available:
|
||||||
import xformers
|
import xformers
|
||||||
|
|
||||||
@ -99,12 +145,27 @@ Use --skip-version-check commandline argument to disable this check.
|
|||||||
|
|
||||||
|
|
||||||
def initialize():
|
def initialize():
|
||||||
|
fix_asyncio_event_loop_policy()
|
||||||
|
|
||||||
check_versions()
|
check_versions()
|
||||||
|
|
||||||
extensions.list_extensions()
|
extensions.list_extensions()
|
||||||
localization.list_localizations(cmd_opts.localizations_dir)
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
startup_timer.record("list extensions")
|
startup_timer.record("list extensions")
|
||||||
|
|
||||||
|
config_state_file = shared.opts.restore_config_state_file
|
||||||
|
shared.opts.restore_config_state_file = ""
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
if os.path.isfile(config_state_file):
|
||||||
|
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||||
|
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||||
|
config_state = json.load(f)
|
||||||
|
config_states.restore_extension_config(config_state)
|
||||||
|
startup_timer.record("restore extension config")
|
||||||
|
elif config_state_file:
|
||||||
|
print(f"!!! Config state backup not found: {config_state_file}")
|
||||||
|
|
||||||
if cmd_opts.ui_debug_mode:
|
if cmd_opts.ui_debug_mode:
|
||||||
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
|
||||||
modules.scripts.load_scripts()
|
modules.scripts.load_scripts()
|
||||||
@ -126,30 +187,20 @@ def initialize():
|
|||||||
modules.scripts.load_scripts()
|
modules.scripts.load_scripts()
|
||||||
startup_timer.record("load scripts")
|
startup_timer.record("load scripts")
|
||||||
|
|
||||||
modelloader.load_upscalers()
|
|
||||||
startup_timer.record("load upscalers")
|
|
||||||
|
|
||||||
modules.sd_vae.refresh_vae_list()
|
modules.sd_vae.refresh_vae_list()
|
||||||
startup_timer.record("refresh VAE")
|
startup_timer.record("refresh VAE")
|
||||||
|
|
||||||
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
|
||||||
startup_timer.record("refresh textual inversion templates")
|
startup_timer.record("refresh textual inversion templates")
|
||||||
|
|
||||||
try:
|
# load model in parallel to other startup stuff
|
||||||
modules.sd_models.load_model()
|
Thread(target=lambda: shared.sd_model).start()
|
||||||
except Exception as e:
|
|
||||||
errors.display(e, "loading stable diffusion model")
|
|
||||||
print("", file=sys.stderr)
|
|
||||||
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
|
|
||||||
exit(1)
|
|
||||||
startup_timer.record("load SD checkpoint")
|
|
||||||
|
|
||||||
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
|
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
|
||||||
|
|
||||||
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
|
|
||||||
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
|
||||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||||
|
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||||
startup_timer.record("opts onchange")
|
startup_timer.record("opts onchange")
|
||||||
|
|
||||||
shared.reload_hypernetworks()
|
shared.reload_hypernetworks()
|
||||||
@ -212,6 +263,8 @@ def wait_on_server(demo=None):
|
|||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
demo.close()
|
demo.close()
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
modules.script_callbacks.app_reload_callback()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
@ -227,7 +280,6 @@ def api_only():
|
|||||||
print(f"Startup time: {startup_timer.summary()}.")
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
|
||||||
|
|
||||||
|
|
||||||
def webui():
|
def webui():
|
||||||
launch_api = cmd_opts.api
|
launch_api = cmd_opts.api
|
||||||
initialize()
|
initialize()
|
||||||
@ -254,12 +306,23 @@ def webui():
|
|||||||
for line in file.readlines():
|
for line in file.readlines():
|
||||||
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
|
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
|
||||||
|
|
||||||
|
# this restores the missing /docs endpoint
|
||||||
|
if launch_api and not hasattr(FastAPI, 'original_setup'):
|
||||||
|
def fastapi_setup(self):
|
||||||
|
self.docs_url = "/docs"
|
||||||
|
self.redoc_url = "/redoc"
|
||||||
|
self.original_setup()
|
||||||
|
|
||||||
|
FastAPI.original_setup = FastAPI.setup
|
||||||
|
FastAPI.setup = fastapi_setup
|
||||||
|
|
||||||
app, local_url, share_url = shared.demo.launch(
|
app, local_url, share_url = shared.demo.launch(
|
||||||
share=cmd_opts.share,
|
share=cmd_opts.share,
|
||||||
server_name=server_name,
|
server_name=server_name,
|
||||||
server_port=cmd_opts.port,
|
server_port=cmd_opts.port,
|
||||||
ssl_keyfile=cmd_opts.tls_keyfile,
|
ssl_keyfile=cmd_opts.tls_keyfile,
|
||||||
ssl_certfile=cmd_opts.tls_certfile,
|
ssl_certfile=cmd_opts.tls_certfile,
|
||||||
|
ssl_verify=cmd_opts.disable_tls_verify,
|
||||||
debug=cmd_opts.gradio_debug,
|
debug=cmd_opts.gradio_debug,
|
||||||
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
|
auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None,
|
||||||
inbrowser=cmd_opts.autolaunch,
|
inbrowser=cmd_opts.autolaunch,
|
||||||
@ -290,6 +353,11 @@ def webui():
|
|||||||
|
|
||||||
print(f"Startup time: {startup_timer.summary()}.")
|
print(f"Startup time: {startup_timer.summary()}.")
|
||||||
|
|
||||||
|
if cmd_opts.subpath:
|
||||||
|
redirector = FastAPI()
|
||||||
|
redirector.get("/")
|
||||||
|
mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
|
||||||
|
|
||||||
wait_on_server(shared.demo)
|
wait_on_server(shared.demo)
|
||||||
print('Restarting UI...')
|
print('Restarting UI...')
|
||||||
|
|
||||||
@ -301,6 +369,19 @@ def webui():
|
|||||||
extensions.list_extensions()
|
extensions.list_extensions()
|
||||||
startup_timer.record("list extensions")
|
startup_timer.record("list extensions")
|
||||||
|
|
||||||
|
config_state_file = shared.opts.restore_config_state_file
|
||||||
|
shared.opts.restore_config_state_file = ""
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
if os.path.isfile(config_state_file):
|
||||||
|
print(f"*** About to restore extension state from file: {config_state_file}")
|
||||||
|
with open(config_state_file, "r", encoding="utf-8") as f:
|
||||||
|
config_state = json.load(f)
|
||||||
|
config_states.restore_extension_config(config_state)
|
||||||
|
startup_timer.record("restore extension config")
|
||||||
|
elif config_state_file:
|
||||||
|
print(f"!!! Config state backup not found: {config_state_file}")
|
||||||
|
|
||||||
localization.list_localizations(cmd_opts.localizations_dir)
|
localization.list_localizations(cmd_opts.localizations_dir)
|
||||||
|
|
||||||
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
modelloader.forbid_loaded_nonbuiltin_upscalers()
|
||||||
|
59
webui.sh
59
webui.sh
@ -23,7 +23,7 @@ fi
|
|||||||
# Install directory without trailing slash
|
# Install directory without trailing slash
|
||||||
if [[ -z "${install_dir}" ]]
|
if [[ -z "${install_dir}" ]]
|
||||||
then
|
then
|
||||||
install_dir="/home/$(whoami)"
|
install_dir="$(pwd)"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Name of the subdirectory (defaults to stable-diffusion-webui)
|
# Name of the subdirectory (defaults to stable-diffusion-webui)
|
||||||
@ -113,12 +113,13 @@ case "$gpu_info" in
|
|||||||
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
|
printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half"
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
;;
|
;;
|
||||||
esac
|
esac
|
||||||
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
||||||
then
|
then
|
||||||
export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
|
# AMD users will still use torch 1.13 because 2.0 does not seem to work.
|
||||||
|
export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
for preq in "${GIT}" "${python_cmd}"
|
for preq in "${GIT}" "${python_cmd}"
|
||||||
@ -152,35 +153,57 @@ else
|
|||||||
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
|
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
|
||||||
fi
|
fi
|
||||||
|
|
||||||
printf "\n%s\n" "${delimiter}"
|
if [[ -z "${VIRTUAL_ENV}" ]];
|
||||||
printf "Create and activate python venv"
|
|
||||||
printf "\n%s\n" "${delimiter}"
|
|
||||||
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
|
|
||||||
if [[ ! -d "${venv_dir}" ]]
|
|
||||||
then
|
then
|
||||||
"${python_cmd}" -m venv "${venv_dir}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
first_launch=1
|
printf "Create and activate python venv"
|
||||||
fi
|
printf "\n%s\n" "${delimiter}"
|
||||||
# shellcheck source=/dev/null
|
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
|
||||||
if [[ -f "${venv_dir}"/bin/activate ]]
|
if [[ ! -d "${venv_dir}" ]]
|
||||||
then
|
then
|
||||||
source "${venv_dir}"/bin/activate
|
"${python_cmd}" -m venv "${venv_dir}"
|
||||||
|
first_launch=1
|
||||||
|
fi
|
||||||
|
# shellcheck source=/dev/null
|
||||||
|
if [[ -f "${venv_dir}"/bin/activate ]]
|
||||||
|
then
|
||||||
|
source "${venv_dir}"/bin/activate
|
||||||
|
else
|
||||||
|
printf "\n%s\n" "${delimiter}"
|
||||||
|
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
|
||||||
|
printf "\n%s\n" "${delimiter}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
else
|
else
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
|
printf "python venv already activate: ${VIRTUAL_ENV}"
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Try using TCMalloc on Linux
|
||||||
|
prepare_tcmalloc() {
|
||||||
|
if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then
|
||||||
|
TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)"
|
||||||
|
if [[ ! -z "${TCMALLOC}" ]]; then
|
||||||
|
echo "Using TCMalloc: ${TCMALLOC}"
|
||||||
|
export LD_PRELOAD="${TCMALLOC}"
|
||||||
|
else
|
||||||
|
printf "\e[1m\e[31mCannot locate TCMalloc (improves CPU memory usage)\e[0m\n"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
|
if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]
|
||||||
then
|
then
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
printf "Accelerating launch.py..."
|
printf "Accelerating launch.py..."
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
|
prepare_tcmalloc
|
||||||
exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
|
exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@"
|
||||||
else
|
else
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
printf "Launching launch.py..."
|
printf "Launching launch.py..."
|
||||||
printf "\n%s\n" "${delimiter}"
|
printf "\n%s\n" "${delimiter}"
|
||||||
|
prepare_tcmalloc
|
||||||
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
|
exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
|
||||||
fi
|
fi
|
||||||
|
Loading…
Reference in New Issue
Block a user