Merge branch 'dev' into improve-frontend-responsiveness
This commit is contained in:
commit
04b4508a66
21
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
21
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@ -47,6 +47,15 @@ body:
|
|||||||
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
|
||||||
validations:
|
validations:
|
||||||
required: true
|
required: true
|
||||||
|
- type: dropdown
|
||||||
|
id: py-version
|
||||||
|
attributes:
|
||||||
|
label: What Python version are you running on ?
|
||||||
|
multiple: false
|
||||||
|
options:
|
||||||
|
- Python 3.10.x
|
||||||
|
- Python 3.11.x (above, no supported yet)
|
||||||
|
- Python 3.9.x (below, no recommended)
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
id: platforms
|
id: platforms
|
||||||
attributes:
|
attributes:
|
||||||
@ -59,6 +68,18 @@ body:
|
|||||||
- iOS
|
- iOS
|
||||||
- Android
|
- Android
|
||||||
- Other/Cloud
|
- Other/Cloud
|
||||||
|
- type: dropdown
|
||||||
|
id: device
|
||||||
|
attributes:
|
||||||
|
label: What device are you running WebUI on?
|
||||||
|
multiple: true
|
||||||
|
options:
|
||||||
|
- Nvidia GPUs (RTX 20 above)
|
||||||
|
- Nvidia GPUs (GTX 16 below)
|
||||||
|
- AMD GPUs (RX 6000 above)
|
||||||
|
- AMD GPUs (RX 5000 below)
|
||||||
|
- CPU
|
||||||
|
- Other GPUs
|
||||||
- type: dropdown
|
- type: dropdown
|
||||||
id: browsers
|
id: browsers
|
||||||
attributes:
|
attributes:
|
||||||
|
43
.github/workflows/on_pull_request.yaml
vendored
43
.github/workflows/on_pull_request.yaml
vendored
@ -18,22 +18,29 @@ jobs:
|
|||||||
steps:
|
steps:
|
||||||
- name: Checkout Code
|
- name: Checkout Code
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.10
|
- uses: actions/setup-python@v4
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
with:
|
||||||
python-version: 3.10.6
|
python-version: 3.11
|
||||||
cache: pip
|
# NB: there's no cache: pip here since we're not installing anything
|
||||||
cache-dependency-path: |
|
# from the requirements.txt file(s) in the repository; it's faster
|
||||||
**/requirements*txt
|
# not to have GHA download an (at the time of writing) 4 GB cache
|
||||||
- name: Install PyLint
|
# of PyTorch and other dependencies.
|
||||||
run: |
|
- name: Install Ruff
|
||||||
python -m pip install --upgrade pip
|
run: pip install ruff==0.0.265
|
||||||
pip install pylint
|
- name: Run Ruff
|
||||||
# This lets PyLint check to see if it can resolve imports
|
run: ruff .
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
# The rest are currently disabled pending fixing of e.g. installing the torch dependency.
|
||||||
export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
|
||||||
python launch.py
|
# - name: Install PyLint
|
||||||
- name: Analysing the code with pylint
|
# run: |
|
||||||
run: |
|
# python -m pip install --upgrade pip
|
||||||
pylint $(git ls-files '*.py')
|
# pip install pylint
|
||||||
|
# # This lets PyLint check to see if it can resolve imports
|
||||||
|
# - name: Install dependencies
|
||||||
|
# run: |
|
||||||
|
# export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit"
|
||||||
|
# python launch.py
|
||||||
|
# - name: Analysing the code with pylint
|
||||||
|
# run: |
|
||||||
|
# pylint $(git ls-files '*.py')
|
||||||
|
6
.github/workflows/run_tests.yaml
vendored
6
.github/workflows/run_tests.yaml
vendored
@ -17,8 +17,14 @@ jobs:
|
|||||||
cache: pip
|
cache: pip
|
||||||
cache-dependency-path: |
|
cache-dependency-path: |
|
||||||
**/requirements*txt
|
**/requirements*txt
|
||||||
|
launch.py
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
run: python launch.py --tests test --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test
|
||||||
|
env:
|
||||||
|
PIP_DISABLE_PIP_VERSION_CHECK: "1"
|
||||||
|
PIP_PROGRESS_BAR: "off"
|
||||||
|
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
|
||||||
|
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
|
||||||
- name: Upload main app stdout-stderr
|
- name: Upload main app stdout-stderr
|
||||||
uses: actions/upload-artifact@v3
|
uses: actions/upload-artifact@v3
|
||||||
if: always()
|
if: always()
|
||||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -32,4 +32,5 @@ notification.mp3
|
|||||||
/extensions
|
/extensions
|
||||||
/test/stdout.txt
|
/test/stdout.txt
|
||||||
/test/stderr.txt
|
/test/stderr.txt
|
||||||
/cache.json
|
/cache.json*
|
||||||
|
/config_states/
|
||||||
|
120
CHANGELOG.md
Normal file
120
CHANGELOG.md
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
## 1.2.1
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* add an option to always refer to lora by filenames
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* never refer to lora by an alias if multiple loras have same alias or the alias is called none
|
||||||
|
* fix upscalers disappearing after the user reloads UI
|
||||||
|
* allow bf16 in safe unpickler (resolves problems with loading some loras)
|
||||||
|
* allow web UI to be ran fully offline
|
||||||
|
* fix localizations not working
|
||||||
|
* fix error for loras: 'LatentDiffusion' object has no attribute 'lora_layer_mapping'
|
||||||
|
|
||||||
|
## 1.2.0
|
||||||
|
|
||||||
|
### Features:
|
||||||
|
* do not wait for stable diffusion model to load at startup
|
||||||
|
* add filename patterns: [denoising]
|
||||||
|
* directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
|
||||||
|
* Lora: for the `<...>` text in prompt, use name of Lora that is in the metdata of the file, if present, instead of filename (both can be used to activate lora)
|
||||||
|
* Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
|
||||||
|
* Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
|
||||||
|
* Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
|
||||||
|
* add version to infotext, footer and console output when starting
|
||||||
|
* add links to wiki for filename pattern settings
|
||||||
|
* add extended info for quicksettings setting and use multiselect input instead of a text field
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* gradio bumped to 3.29.0
|
||||||
|
* torch bumped to 2.0.1
|
||||||
|
* --subpath option for gradio for use with reverse proxy
|
||||||
|
* linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
|
||||||
|
* possible frontend optimization: do not apply localizations if there are none
|
||||||
|
* Add extra `None` option for VAE in XYZ plot
|
||||||
|
* print error to console when batch processing in img2img fails
|
||||||
|
* create HTML for extra network pages only on demand
|
||||||
|
* allow directories starting with . to still list their models for lora, checkpoints, etc
|
||||||
|
* put infotext options into their own category in settings tab
|
||||||
|
* do not show licenses page when user selects Show all pages in settings
|
||||||
|
|
||||||
|
### Extensions:
|
||||||
|
* Tooltip localization support
|
||||||
|
* Add api method to get LoRA models with prompt
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* re-add /docs endpoint
|
||||||
|
* fix gamepad navigation
|
||||||
|
* make the lightbox fullscreen image function properly
|
||||||
|
* fix squished thumbnails in extras tab
|
||||||
|
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
|
||||||
|
* fix webui showing the same image if you configure the generation to always save results into same file
|
||||||
|
* fix bug with upscalers not working properly
|
||||||
|
* Fix MPS on PyTorch 2.0.1, Intel Macs
|
||||||
|
* make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
|
||||||
|
* prevent Reload UI button/link from reloading the page when it's not yet ready
|
||||||
|
* fix prompts from file script failing to read contents from a drag/drop file
|
||||||
|
|
||||||
|
|
||||||
|
## 1.1.1
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
|
||||||
|
|
||||||
|
## 1.1.0
|
||||||
|
### Features:
|
||||||
|
* switch to torch 2.0.0 (except for AMD GPUs)
|
||||||
|
* visual improvements to custom code scripts
|
||||||
|
* add filename patterns: [clip_skip], [hasprompt<>], [batch_number], [generation_number]
|
||||||
|
* add support for saving init images in img2img, and record their hashes in infotext for reproducability
|
||||||
|
* automatically select current word when adjusting weight with ctrl+up/down
|
||||||
|
* add dropdowns for X/Y/Z plot
|
||||||
|
* setting: Stable Diffusion/Random number generator source: makes it possible to make images generated from a given manual seed consistent across different GPUs
|
||||||
|
* support Gradio's theme API
|
||||||
|
* use TCMalloc on Linux by default; possible fix for memory leaks
|
||||||
|
* (optimization) option to remove negative conditioning at low sigma values #9177
|
||||||
|
* embed model merge metadata in .safetensors file
|
||||||
|
* extension settings backup/restore feature #9169
|
||||||
|
* add "resize by" and "resize to" tabs to img2img
|
||||||
|
* add option "keep original size" to textual inversion images preprocess
|
||||||
|
* image viewer scrolling via analog stick
|
||||||
|
* button to restore the progress from session lost / tab reload
|
||||||
|
|
||||||
|
### Minor:
|
||||||
|
* gradio bumped to 3.28.1
|
||||||
|
* in extra tab, change extras "scale to" to sliders
|
||||||
|
* add labels to tool buttons to make it possible to hide them
|
||||||
|
* add tiled inference support for ScuNET
|
||||||
|
* add branch support for extension installation
|
||||||
|
* change linux installation script to insall into current directory rather than /home/username
|
||||||
|
* sort textual inversion embeddings by name (case insensitive)
|
||||||
|
* allow styles.csv to be symlinked or mounted in docker
|
||||||
|
* remove the "do not add watermark to images" option
|
||||||
|
* make selected tab configurable with UI config
|
||||||
|
* extra networks UI in now fixed height and scrollable
|
||||||
|
* add disable_tls_verify arg for use with self-signed certs
|
||||||
|
|
||||||
|
### Extensions:
|
||||||
|
* Add reload callback
|
||||||
|
* add is_hr_pass field for processing
|
||||||
|
|
||||||
|
### Bug Fixes:
|
||||||
|
* fix broken batch image processing on 'Extras/Batch Process' tab
|
||||||
|
* add "None" option to extra networks dropdowns
|
||||||
|
* fix FileExistsError for CLIP Interrogator
|
||||||
|
* fix /sdapi/v1/txt2img endpoint not working on Linux #9319
|
||||||
|
* fix disappearing live previews and progressbar during slow tasks
|
||||||
|
* fix fullscreen image view not working properly in some cases
|
||||||
|
* prevent alwayson_scripts args param resizing script_arg list when they are inserted in it
|
||||||
|
* fix prompt schedule for second order samplers
|
||||||
|
* fix image mask/composite for weird resolutions #9628
|
||||||
|
* use correct images for previews when using AND (see #9491)
|
||||||
|
* one broken image in img2img batch won't stop all processing
|
||||||
|
* fix image orientation bug in train/preprocess
|
||||||
|
* fix Ngrok recreating tunnels every reload
|
||||||
|
* fix --realesrgan-models-path and --ldsr-models-path not working
|
||||||
|
* fix --skip-install not working
|
||||||
|
* outpainting Mk2 & Poorman should use the SAMPLE file format to save images, not GRID file format
|
||||||
|
* do not fail all Loras if some have failed to load when making a picture
|
||||||
|
|
||||||
|
## 1.0.0
|
||||||
|
* everything
|
14
README.md
14
README.md
@ -99,8 +99,14 @@ Alternatively, use online services (like Google Colab):
|
|||||||
|
|
||||||
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
|
||||||
|
|
||||||
|
### Installation on Windows 10/11 with NVidia-GPUs using release package
|
||||||
|
1. Download `sd.webui.zip` from [v1.0.0-pre](https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/tag/v1.0.0-pre) and extract it's contents.
|
||||||
|
2. Run `update.bat`.
|
||||||
|
3. Run `run.bat`.
|
||||||
|
> For more details see [Install-and-Run-on-NVidia-GPUs](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs)
|
||||||
|
|
||||||
### Automatic Installation on Windows
|
### Automatic Installation on Windows
|
||||||
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH".
|
1. Install [Python 3.10.6](https://www.python.org/downloads/release/python-3106/) (Newer version of Python does not support torch), checking "Add Python to PATH".
|
||||||
2. Install [git](https://git-scm.com/download/win).
|
2. Install [git](https://git-scm.com/download/win).
|
||||||
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
|
3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`.
|
||||||
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
|
4. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user.
|
||||||
@ -115,11 +121,12 @@ sudo dnf install wget git python3
|
|||||||
# Arch-based:
|
# Arch-based:
|
||||||
sudo pacman -S wget git python3
|
sudo pacman -S wget git python3
|
||||||
```
|
```
|
||||||
2. To install in `/home/$(whoami)/stable-diffusion-webui/`, run:
|
2. Navigate to the directory you would like the webui to be installed and execute the following command:
|
||||||
```bash
|
```bash
|
||||||
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/master/webui.sh)
|
||||||
```
|
```
|
||||||
3. Run `webui.sh`.
|
3. Run `webui.sh`.
|
||||||
|
4. Check `webui-user.sh` for options.
|
||||||
### Installation on Apple Silicon
|
### Installation on Apple Silicon
|
||||||
|
|
||||||
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
|
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
|
||||||
@ -157,5 +164,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
|
|||||||
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
- Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
|
||||||
- Security advice - RyotaK
|
- Security advice - RyotaK
|
||||||
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
|
||||||
|
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
|
||||||
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
|
||||||
- (You)
|
- (You)
|
||||||
|
@ -4,8 +4,8 @@ channels:
|
|||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
- python=3.10
|
- python=3.10
|
||||||
- pip=22.2.2
|
- pip=23.0
|
||||||
- cudatoolkit=11.3
|
- cudatoolkit=11.8
|
||||||
- pytorch=1.12.1
|
- pytorch=2.0
|
||||||
- torchvision=0.13.1
|
- torchvision=0.15
|
||||||
- numpy=1.23.1
|
- numpy=1.23
|
||||||
|
@ -88,7 +88,7 @@ class LDSR:
|
|||||||
|
|
||||||
x_t = None
|
x_t = None
|
||||||
logs = None
|
logs = None
|
||||||
for n in range(n_runs):
|
for _ in range(n_runs):
|
||||||
if custom_shape is not None:
|
if custom_shape is not None:
|
||||||
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
|
||||||
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
|
||||||
@ -110,7 +110,6 @@ class LDSR:
|
|||||||
diffusion_steps = int(steps)
|
diffusion_steps = int(steps)
|
||||||
eta = 1.0
|
eta = 1.0
|
||||||
|
|
||||||
down_sample_method = 'Lanczos'
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if torch.cuda.is_available:
|
if torch.cuda.is_available:
|
||||||
@ -131,11 +130,11 @@ class LDSR:
|
|||||||
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
|
||||||
else:
|
else:
|
||||||
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
|
||||||
|
|
||||||
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
|
||||||
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
|
||||||
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
|
||||||
|
|
||||||
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
|
||||||
|
|
||||||
sample = logs["sample"]
|
sample = logs["sample"]
|
||||||
@ -158,7 +157,7 @@ class LDSR:
|
|||||||
|
|
||||||
|
|
||||||
def get_cond(selected_path):
|
def get_cond(selected_path):
|
||||||
example = dict()
|
example = {}
|
||||||
up_f = 4
|
up_f = 4
|
||||||
c = selected_path.convert('RGB')
|
c = selected_path.convert('RGB')
|
||||||
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
|
||||||
@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
|
||||||
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
|
||||||
log = dict()
|
log = {}
|
||||||
|
|
||||||
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize
|
|||||||
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
|
||||||
log["sample_noquant"] = x_sample_noquant
|
log["sample_noquant"] = x_sample_noquant
|
||||||
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log["sample"] = x_sample
|
log["sample"] = x_sample
|
||||||
|
@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from ldsr_model_arch import LDSR
|
from ldsr_model_arch import LDSR
|
||||||
from modules import shared, script_callbacks
|
from modules import shared, script_callbacks
|
||||||
import sd_hijack_autoencoder, sd_hijack_ddpm_v1
|
import sd_hijack_autoencoder # noqa: F401
|
||||||
|
import sd_hijack_ddpm_v1 # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
class UpscalerLDSR(Upscaler):
|
class UpscalerLDSR(Upscaler):
|
||||||
@ -25,22 +26,28 @@ class UpscalerLDSR(Upscaler):
|
|||||||
yaml_path = os.path.join(self.model_path, "project.yaml")
|
yaml_path = os.path.join(self.model_path, "project.yaml")
|
||||||
old_model_path = os.path.join(self.model_path, "model.pth")
|
old_model_path = os.path.join(self.model_path, "model.pth")
|
||||||
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
new_model_path = os.path.join(self.model_path, "model.ckpt")
|
||||||
safetensors_model_path = os.path.join(self.model_path, "model.safetensors")
|
|
||||||
|
local_model_paths = self.find_models(ext_filter=[".ckpt", ".safetensors"])
|
||||||
|
local_ckpt_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.ckpt")]), None)
|
||||||
|
local_safetensors_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("model.safetensors")]), None)
|
||||||
|
local_yaml_path = next(iter([local_model for local_model in local_model_paths if local_model.endswith("project.yaml")]), None)
|
||||||
|
|
||||||
if os.path.exists(yaml_path):
|
if os.path.exists(yaml_path):
|
||||||
statinfo = os.stat(yaml_path)
|
statinfo = os.stat(yaml_path)
|
||||||
if statinfo.st_size >= 10485760:
|
if statinfo.st_size >= 10485760:
|
||||||
print("Removing invalid LDSR YAML file.")
|
print("Removing invalid LDSR YAML file.")
|
||||||
os.remove(yaml_path)
|
os.remove(yaml_path)
|
||||||
|
|
||||||
if os.path.exists(old_model_path):
|
if os.path.exists(old_model_path):
|
||||||
print("Renaming model from model.pth to model.ckpt")
|
print("Renaming model from model.pth to model.ckpt")
|
||||||
os.rename(old_model_path, new_model_path)
|
os.rename(old_model_path, new_model_path)
|
||||||
if os.path.exists(safetensors_model_path):
|
|
||||||
model = safetensors_model_path
|
if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
|
||||||
|
model = local_safetensors_path
|
||||||
else:
|
else:
|
||||||
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
|
||||||
file_name="model.ckpt", progress=True)
|
|
||||||
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
|
yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
|
||||||
file_name="project.yaml", progress=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return LDSR(model, yaml)
|
return LDSR(model, yaml)
|
||||||
|
@ -1,16 +1,21 @@
|
|||||||
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo
|
||||||
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo
|
||||||
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from torch.optim.lr_scheduler import LambdaLR
|
||||||
|
|
||||||
|
from ldm.modules.ema import LitEma
|
||||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
import ldm.models.autoencoder
|
import ldm.models.autoencoder
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
class VQModel(pl.LightningModule):
|
class VQModel(pl.LightningModule):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -19,7 +24,7 @@ class VQModel(pl.LightningModule):
|
|||||||
n_embed,
|
n_embed,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
image_key="image",
|
image_key="image",
|
||||||
colorize_nlabels=None,
|
colorize_nlabels=None,
|
||||||
monitor=None,
|
monitor=None,
|
||||||
@ -57,7 +62,7 @@ class VQModel(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [])
|
||||||
self.scheduler_config = scheduler_config
|
self.scheduler_config = scheduler_config
|
||||||
self.lr_g_factor = lr_g_factor
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
@ -76,11 +81,11 @@ class VQModel(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
def init_from_ckpt(self, path, ignore_keys=None):
|
||||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -165,7 +170,7 @@ class VQModel(pl.LightningModule):
|
|||||||
def validation_step(self, batch, batch_idx):
|
def validation_step(self, batch, batch_idx):
|
||||||
log_dict = self._validation_step(batch, batch_idx)
|
log_dict = self._validation_step(batch, batch_idx)
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
self._validation_step(batch, batch_idx, suffix="_ema")
|
||||||
return log_dict
|
return log_dict
|
||||||
|
|
||||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||||
@ -232,7 +237,7 @@ class VQModel(pl.LightningModule):
|
|||||||
return self.decoder.conv_out.weight
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.image_key)
|
x = self.get_input(batch, self.image_key)
|
||||||
x = x.to(self.device)
|
x = x.to(self.device)
|
||||||
if only_inputs:
|
if only_inputs:
|
||||||
@ -249,7 +254,8 @@ class VQModel(pl.LightningModule):
|
|||||||
if plot_ema:
|
if plot_ema:
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
xrec_ema, _ = self(x)
|
xrec_ema, _ = self(x)
|
||||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
if x.shape[1] > 3:
|
||||||
|
xrec_ema = self.to_rgb(xrec_ema)
|
||||||
log["reconstructions_ema"] = xrec_ema
|
log["reconstructions_ema"] = xrec_ema
|
||||||
return log
|
return log
|
||||||
|
|
||||||
@ -264,7 +270,7 @@ class VQModel(pl.LightningModule):
|
|||||||
|
|
||||||
class VQModelInterface(VQModel):
|
class VQModelInterface(VQModel):
|
||||||
def __init__(self, embed_dim, *args, **kwargs):
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
super().__init__(*args, embed_dim=embed_dim, **kwargs)
|
||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
@ -282,5 +288,5 @@ class VQModelInterface(VQModel):
|
|||||||
dec = self.decoder(quant)
|
dec = self.decoder(quant)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
setattr(ldm.models.autoencoder, "VQModel", VQModel)
|
ldm.models.autoencoder.VQModel = VQModel
|
||||||
setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface)
|
ldm.models.autoencoder.VQModelInterface = VQModelInterface
|
||||||
|
@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if monitor is not None:
|
if monitor is not None:
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
|
||||||
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||||
@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
keys = list(sd.keys())
|
keys = list(sd.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys or []:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
del sd[k]
|
del sd[k]
|
||||||
@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
self.instantiate_cond_stage(cond_stage_config)
|
self.instantiate_cond_stage(cond_stage_config)
|
||||||
self.cond_stage_forward = cond_stage_forward
|
self.cond_stage_forward = cond_stage_forward
|
||||||
self.clip_denoised = False
|
self.clip_denoised = False
|
||||||
self.bbox_tokenizer = None
|
self.bbox_tokenizer = None
|
||||||
|
|
||||||
self.restarted_from_ckpt = False
|
self.restarted_from_ckpt = False
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
|
||||||
|
|
||||||
# 2. apply model loop over last dim
|
# 2. apply model loop over last dim
|
||||||
if isinstance(self.first_stage_model, VQModelInterface):
|
if isinstance(self.first_stage_model, VQModelInterface):
|
||||||
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
|
||||||
force_not_quantize=predict_cids or force_not_quantize)
|
force_not_quantize=predict_cids or force_not_quantize)
|
||||||
for i in range(z.shape[-1])]
|
for i in range(z.shape[-1])]
|
||||||
@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if hasattr(self, "split_input_params"):
|
if hasattr(self, "split_input_params"):
|
||||||
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
assert len(cond) == 1 # todo can only deal with one conditioning atm
|
||||||
assert not return_ids
|
assert not return_ids
|
||||||
ks = self.split_input_params["ks"] # eg. (128, 128)
|
ks = self.split_input_params["ks"] # eg. (128, 128)
|
||||||
stride = self.split_input_params["stride"] # eg. (64, 64)
|
stride = self.split_input_params["stride"] # eg. (64, 64)
|
||||||
|
|
||||||
@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
use_ddim = ddim_steps is not None
|
use_ddim = ddim_steps is not None
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1):
|
|||||||
logs['bbox_image'] = cond_img
|
logs['bbox_image'] = cond_img
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1)
|
ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1)
|
ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1)
|
ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1
|
||||||
setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1)
|
ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from modules import extra_networks, shared
|
from modules import extra_networks, shared
|
||||||
import lora
|
import lora
|
||||||
|
|
||||||
|
|
||||||
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__('lora')
|
super().__init__('lora')
|
||||||
@ -8,7 +9,7 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
|
|||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_lora
|
additional = shared.opts.sd_lora
|
||||||
|
|
||||||
if additional != "" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||||
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
|
|
||||||
|
@ -1,10 +1,9 @@
|
|||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import torch
|
import torch
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from modules import shared, devices, sd_models, errors
|
from modules import shared, devices, sd_models, errors, scripts
|
||||||
|
|
||||||
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
|
||||||
|
|
||||||
@ -93,6 +92,7 @@ class LoraOnDisk:
|
|||||||
self.metadata = m
|
self.metadata = m
|
||||||
|
|
||||||
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
|
||||||
|
self.alias = self.metadata.get('ss_output_name', self.name)
|
||||||
|
|
||||||
|
|
||||||
class LoraModule:
|
class LoraModule:
|
||||||
@ -132,6 +132,10 @@ def load_lora(name, filename):
|
|||||||
|
|
||||||
sd = sd_models.read_state_dict(filename)
|
sd = sd_models.read_state_dict(filename)
|
||||||
|
|
||||||
|
# this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
|
||||||
|
if not hasattr(shared.sd_model, 'lora_layer_mapping'):
|
||||||
|
assign_lora_names_to_compvis_modules(shared.sd_model)
|
||||||
|
|
||||||
keys_failed_to_match = {}
|
keys_failed_to_match = {}
|
||||||
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
|
||||||
|
|
||||||
@ -165,12 +169,14 @@ def load_lora(name, filename):
|
|||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.MultiheadAttention:
|
elif type(sd_module) == torch.nn.MultiheadAttention:
|
||||||
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
|
||||||
elif type(sd_module) == torch.nn.Conv2d:
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
|
||||||
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
|
||||||
|
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
|
||||||
|
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
|
||||||
else:
|
else:
|
||||||
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
|
||||||
continue
|
continue
|
||||||
assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
|
raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
module.weight.copy_(weight)
|
module.weight.copy_(weight)
|
||||||
@ -182,7 +188,7 @@ def load_lora(name, filename):
|
|||||||
elif lora_key == "lora_down.weight":
|
elif lora_key == "lora_down.weight":
|
||||||
lora_module.down = module
|
lora_module.down = module
|
||||||
else:
|
else:
|
||||||
assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
|
raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
|
||||||
|
|
||||||
if len(keys_failed_to_match) > 0:
|
if len(keys_failed_to_match) > 0:
|
||||||
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
|
||||||
@ -199,11 +205,11 @@ def load_loras(names, multipliers=None):
|
|||||||
|
|
||||||
loaded_loras.clear()
|
loaded_loras.clear()
|
||||||
|
|
||||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
if any([x is None for x in loras_on_disk]):
|
if any(x is None for x in loras_on_disk):
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
|
||||||
loras_on_disk = [available_loras.get(name, None) for name in names]
|
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
|
||||||
|
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
lora = already_loaded.get(name, None)
|
lora = already_loaded.get(name, None)
|
||||||
@ -211,7 +217,11 @@ def load_loras(names, multipliers=None):
|
|||||||
lora_on_disk = loras_on_disk[i]
|
lora_on_disk = loras_on_disk[i]
|
||||||
if lora_on_disk is not None:
|
if lora_on_disk is not None:
|
||||||
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:
|
||||||
lora = load_lora(name, lora_on_disk.filename)
|
try:
|
||||||
|
lora = load_lora(name, lora_on_disk.filename)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"loading Lora {lora_on_disk.filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
if lora is None:
|
if lora is None:
|
||||||
print(f"Couldn't find Lora with name {name}")
|
print(f"Couldn't find Lora with name {name}")
|
||||||
@ -228,6 +238,8 @@ def lora_calc_updown(lora, module, target):
|
|||||||
|
|
||||||
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
|
||||||
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
|
||||||
|
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
|
||||||
|
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
|
||||||
else:
|
else:
|
||||||
updown = up @ down
|
updown = up @ down
|
||||||
|
|
||||||
@ -236,6 +248,19 @@ def lora_calc_updown(lora, module, target):
|
|||||||
return updown
|
return updown
|
||||||
|
|
||||||
|
|
||||||
|
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
|
weights_backup = getattr(self, "lora_weights_backup", None)
|
||||||
|
|
||||||
|
if weights_backup is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(self, torch.nn.MultiheadAttention):
|
||||||
|
self.in_proj_weight.copy_(weights_backup[0])
|
||||||
|
self.out_proj.weight.copy_(weights_backup[1])
|
||||||
|
else:
|
||||||
|
self.weight.copy_(weights_backup)
|
||||||
|
|
||||||
|
|
||||||
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
|
||||||
"""
|
"""
|
||||||
Applies the currently selected set of Loras to the weights of torch layer self.
|
Applies the currently selected set of Loras to the weights of torch layer self.
|
||||||
@ -260,12 +285,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
self.lora_weights_backup = weights_backup
|
self.lora_weights_backup = weights_backup
|
||||||
|
|
||||||
if current_names != wanted_names:
|
if current_names != wanted_names:
|
||||||
if weights_backup is not None:
|
lora_restore_weights_from_backup(self)
|
||||||
if isinstance(self, torch.nn.MultiheadAttention):
|
|
||||||
self.in_proj_weight.copy_(weights_backup[0])
|
|
||||||
self.out_proj.weight.copy_(weights_backup[1])
|
|
||||||
else:
|
|
||||||
self.weight.copy_(weights_backup)
|
|
||||||
|
|
||||||
for lora in loaded_loras:
|
for lora in loaded_loras:
|
||||||
module = lora.modules.get(lora_layer_name, None)
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
@ -293,15 +313,48 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
|
|||||||
|
|
||||||
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
print(f'failed to calculate lora weights for layer {lora_layer_name}')
|
||||||
|
|
||||||
setattr(self, "lora_current_names", wanted_names)
|
self.lora_current_names = wanted_names
|
||||||
|
|
||||||
|
|
||||||
|
def lora_forward(module, input, original_forward):
|
||||||
|
"""
|
||||||
|
Old way of applying Lora by executing operations during layer's forward.
|
||||||
|
Stacking many loras this way results in big performance degradation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(loaded_loras) == 0:
|
||||||
|
return original_forward(module, input)
|
||||||
|
|
||||||
|
input = devices.cond_cast_unet(input)
|
||||||
|
|
||||||
|
lora_restore_weights_from_backup(module)
|
||||||
|
lora_reset_cached_weight(module)
|
||||||
|
|
||||||
|
res = original_forward(module, input)
|
||||||
|
|
||||||
|
lora_layer_name = getattr(module, 'lora_layer_name', None)
|
||||||
|
for lora in loaded_loras:
|
||||||
|
module = lora.modules.get(lora_layer_name, None)
|
||||||
|
if module is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
module.up.to(device=devices.device)
|
||||||
|
module.down.to(device=devices.device)
|
||||||
|
|
||||||
|
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
|
||||||
setattr(self, "lora_current_names", ())
|
self.lora_current_names = ()
|
||||||
setattr(self, "lora_weights_backup", None)
|
self.lora_weights_backup = None
|
||||||
|
|
||||||
|
|
||||||
def lora_Linear_forward(self, input):
|
def lora_Linear_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
|
||||||
|
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Linear_forward_before_lora(self, input)
|
return torch.nn.Linear_forward_before_lora(self, input)
|
||||||
@ -314,6 +367,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def lora_Conv2d_forward(self, input):
|
def lora_Conv2d_forward(self, input):
|
||||||
|
if shared.opts.lora_functional:
|
||||||
|
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
|
||||||
|
|
||||||
lora_apply_weights(self)
|
lora_apply_weights(self)
|
||||||
|
|
||||||
return torch.nn.Conv2d_forward_before_lora(self, input)
|
return torch.nn.Conv2d_forward_before_lora(self, input)
|
||||||
@ -339,24 +395,65 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
|
|||||||
|
|
||||||
def list_available_loras():
|
def list_available_loras():
|
||||||
available_loras.clear()
|
available_loras.clear()
|
||||||
|
available_lora_aliases.clear()
|
||||||
|
forbidden_lora_aliases.clear()
|
||||||
|
forbidden_lora_aliases.update({"none": 1})
|
||||||
|
|
||||||
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
|
||||||
|
|
||||||
candidates = \
|
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
|
||||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
|
|
||||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
|
|
||||||
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
|
|
||||||
|
|
||||||
for filename in sorted(candidates, key=str.lower):
|
for filename in sorted(candidates, key=str.lower):
|
||||||
if os.path.isdir(filename):
|
if os.path.isdir(filename):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
name = os.path.splitext(os.path.basename(filename))[0]
|
name = os.path.splitext(os.path.basename(filename))[0]
|
||||||
|
entry = LoraOnDisk(name, filename)
|
||||||
|
|
||||||
available_loras[name] = LoraOnDisk(name, filename)
|
available_loras[name] = entry
|
||||||
|
|
||||||
|
if entry.alias in available_lora_aliases:
|
||||||
|
forbidden_lora_aliases[entry.alias.lower()] = 1
|
||||||
|
|
||||||
|
available_lora_aliases[name] = entry
|
||||||
|
available_lora_aliases[entry.alias] = entry
|
||||||
|
|
||||||
|
|
||||||
|
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
|
||||||
|
|
||||||
|
|
||||||
|
def infotext_pasted(infotext, params):
|
||||||
|
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
|
||||||
|
return # if the other extension is active, it will handle those fields, no need to do anything
|
||||||
|
|
||||||
|
added = []
|
||||||
|
|
||||||
|
for k in params:
|
||||||
|
if not k.startswith("AddNet Model "):
|
||||||
|
continue
|
||||||
|
|
||||||
|
num = k[13:]
|
||||||
|
|
||||||
|
if params.get("AddNet Module " + num) != "LoRA":
|
||||||
|
continue
|
||||||
|
|
||||||
|
name = params.get("AddNet Model " + num)
|
||||||
|
if name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re_lora_name.match(name)
|
||||||
|
if m:
|
||||||
|
name = m.group(1)
|
||||||
|
|
||||||
|
multiplier = params.get("AddNet Weight A " + num, "1.0")
|
||||||
|
|
||||||
|
added.append(f"<lora:{name}:{multiplier}>")
|
||||||
|
|
||||||
|
if added:
|
||||||
|
params["Prompt"] += "\n" + "".join(added)
|
||||||
|
|
||||||
available_loras = {}
|
available_loras = {}
|
||||||
|
available_lora_aliases = {}
|
||||||
|
forbidden_lora_aliases = {}
|
||||||
loaded_loras = []
|
loaded_loras = []
|
||||||
|
|
||||||
list_available_loras()
|
list_available_loras()
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
import lora
|
import lora
|
||||||
import extra_networks_lora
|
import extra_networks_lora
|
||||||
import ui_extra_networks_lora
|
import ui_extra_networks_lora
|
||||||
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
|
||||||
|
|
||||||
|
|
||||||
def unload():
|
def unload():
|
||||||
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
|
||||||
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
|
||||||
@ -49,8 +49,34 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
|
|||||||
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
|
||||||
script_callbacks.on_script_unloaded(unload)
|
script_callbacks.on_script_unloaded(unload)
|
||||||
script_callbacks.on_before_ui(before_ui)
|
script_callbacks.on_before_ui(before_ui)
|
||||||
|
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
|
||||||
|
|
||||||
|
|
||||||
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
|
||||||
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": [""] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
|
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
|
||||||
|
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
|
||||||
|
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
|
def create_lora_json(obj: lora.LoraOnDisk):
|
||||||
|
return {
|
||||||
|
"name": obj.name,
|
||||||
|
"alias": obj.alias,
|
||||||
|
"path": obj.filename,
|
||||||
|
"metadata": obj.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def api_loras(_: gr.Blocks, app: FastAPI):
|
||||||
|
@app.get("/sdapi/v1/loras")
|
||||||
|
async def get_loras():
|
||||||
|
return [create_lora_json(obj) for obj in lora.available_loras.values()]
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_app_started(api_loras)
|
||||||
|
|
||||||
|
@ -15,13 +15,19 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
|
|||||||
def list_items(self):
|
def list_items(self):
|
||||||
for name, lora_on_disk in lora.available_loras.items():
|
for name, lora_on_disk in lora.available_loras.items():
|
||||||
path, ext = os.path.splitext(lora_on_disk.filename)
|
path, ext = os.path.splitext(lora_on_disk.filename)
|
||||||
|
|
||||||
|
if shared.opts.lora_preferred_name == "Filename" or lora_on_disk.alias.lower() in lora.forbidden_lora_aliases:
|
||||||
|
alias = name
|
||||||
|
else:
|
||||||
|
alias = lora_on_disk.alias
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"name": name,
|
"name": name,
|
||||||
"filename": path,
|
"filename": path,
|
||||||
"preview": self.find_preview(path),
|
"preview": self.find_preview(path),
|
||||||
"description": self.find_description(path),
|
"description": self.find_description(path),
|
||||||
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
"search_term": self.search_terms_from_path(lora_on_disk.filename),
|
||||||
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
"prompt": json.dumps(f"<lora:{alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
|
||||||
"local_preview": f"{path}.{shared.opts.samples_format}",
|
"local_preview": f"{path}.{shared.opts.samples_format}",
|
||||||
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
|
||||||
}
|
}
|
||||||
|
@ -5,11 +5,14 @@ import traceback
|
|||||||
import PIL.Image
|
import PIL.Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.upscaler
|
import modules.upscaler
|
||||||
from modules import devices, modelloader
|
from modules import devices, modelloader, script_callbacks
|
||||||
from scunet_model_arch import SCUNet as net
|
from scunet_model_arch import SCUNet as net
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
|
|
||||||
class UpscalerScuNET(modules.upscaler.Upscaler):
|
class UpscalerScuNET(modules.upscaler.Upscaler):
|
||||||
@ -42,28 +45,78 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
scalers.append(scaler_data2)
|
scalers.append(scaler_data2)
|
||||||
self.scalers = scalers
|
self.scalers = scalers
|
||||||
|
|
||||||
def do_upscale(self, img: PIL.Image, selected_file):
|
@staticmethod
|
||||||
|
@torch.no_grad()
|
||||||
|
def tiled_inference(img, model):
|
||||||
|
# test the image tile by tile
|
||||||
|
h, w = img.shape[2:]
|
||||||
|
tile = opts.SCUNET_tile
|
||||||
|
tile_overlap = opts.SCUNET_tile_overlap
|
||||||
|
if tile == 0:
|
||||||
|
return model(img)
|
||||||
|
|
||||||
|
device = devices.get_device_for('scunet')
|
||||||
|
assert tile % 8 == 0, "tile size should be a multiple of window_size"
|
||||||
|
sf = 1
|
||||||
|
|
||||||
|
stride = tile - tile_overlap
|
||||||
|
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
|
||||||
|
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
|
||||||
|
E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
|
||||||
|
W = torch.zeros_like(E, dtype=devices.dtype, device=device)
|
||||||
|
|
||||||
|
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
|
||||||
|
for h_idx in h_idx_list:
|
||||||
|
|
||||||
|
for w_idx in w_idx_list:
|
||||||
|
|
||||||
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
|
|
||||||
|
out_patch = model(in_patch)
|
||||||
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
|
||||||
|
E[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch)
|
||||||
|
W[
|
||||||
|
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
|
||||||
|
].add_(out_patch_mask)
|
||||||
|
pbar.update(1)
|
||||||
|
output = E.div_(W)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def do_upscale(self, img: PIL.Image.Image, selected_file):
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
model = self.load_model(selected_file)
|
model = self.load_model(selected_file)
|
||||||
if model is None:
|
if model is None:
|
||||||
|
print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
img = np.array(img)
|
tile = opts.SCUNET_tile
|
||||||
img = img[:, :, ::-1]
|
h, w = img.height, img.width
|
||||||
img = np.moveaxis(img, 2, 0) / 255
|
np_img = np.array(img)
|
||||||
img = torch.from_numpy(img).float()
|
np_img = np_img[:, :, ::-1] # RGB to BGR
|
||||||
img = img.unsqueeze(0).to(device)
|
np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW
|
||||||
|
torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore
|
||||||
|
|
||||||
with torch.no_grad():
|
if tile > h or tile > w:
|
||||||
output = model(img)
|
_img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device)
|
||||||
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
|
_img[:, :, :h, :w] = torch_img # pad image
|
||||||
output = 255. * np.moveaxis(output, 0, 2)
|
torch_img = _img
|
||||||
output = output.astype(np.uint8)
|
|
||||||
output = output[:, :, ::-1]
|
torch_output = self.tiled_inference(torch_img, model).squeeze(0)
|
||||||
|
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
|
||||||
|
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
|
||||||
|
del torch_img, torch_output
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
return PIL.Image.fromarray(output, 'RGB')
|
|
||||||
|
output = np_output.transpose((1, 2, 0)) # CHW to HWC
|
||||||
|
output = output[:, :, ::-1] # BGR to RGB
|
||||||
|
return PIL.Image.fromarray((output * 255).astype(np.uint8))
|
||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
device = devices.get_device_for('scunet')
|
device = devices.get_device_for('scunet')
|
||||||
@ -79,9 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
|
|||||||
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
|
||||||
model.load_state_dict(torch.load(filename), strict=True)
|
model.load_state_dict(torch.load(filename), strict=True)
|
||||||
model.eval()
|
model.eval()
|
||||||
for k, v in model.named_parameters():
|
for _, v in model.named_parameters():
|
||||||
v.requires_grad = False
|
v.requires_grad = False
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def on_ui_settings():
|
||||||
|
import gradio as gr
|
||||||
|
from modules import shared
|
||||||
|
|
||||||
|
shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling"))
|
||||||
|
shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam"))
|
||||||
|
|
||||||
|
|
||||||
|
script_callbacks.on_ui_settings(on_ui_settings)
|
||||||
|
@ -61,7 +61,9 @@ class WMSA(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
output: tensor shape [b h w c]
|
output: tensor shape [b h w c]
|
||||||
"""
|
"""
|
||||||
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
if self.type != 'W':
|
||||||
|
x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
|
||||||
|
|
||||||
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
|
||||||
h_windows = x.size(1)
|
h_windows = x.size(1)
|
||||||
w_windows = x.size(2)
|
w_windows = x.size(2)
|
||||||
@ -85,8 +87,9 @@ class WMSA(nn.Module):
|
|||||||
output = self.linear(output)
|
output = self.linear(output)
|
||||||
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
|
||||||
|
|
||||||
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
|
if self.type != 'W':
|
||||||
dims=(1, 2))
|
output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2))
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def relative_embedding(self):
|
def relative_embedding(self):
|
||||||
@ -262,4 +265,4 @@ class SCUNet(nn.Module):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.LayerNorm):
|
elif isinstance(m, nn.LayerNorm):
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
nn.init.constant_(m.weight, 1.0)
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import contextlib
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from modules import modelloader, devices, script_callbacks, shared
|
from modules import modelloader, devices, script_callbacks, shared
|
||||||
from modules.shared import cmd_opts, opts, state
|
from modules.shared import opts, state
|
||||||
from swinir_model_arch import SwinIR as net
|
from swinir_model_arch import SwinIR as net
|
||||||
from swinir_model_arch_v2 import Swin2SR as net2
|
from swinir_model_arch_v2 import Swin2SR as net2
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler):
|
|||||||
img = upscale(img, model)
|
img = upscale(img, model)
|
||||||
try:
|
try:
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return img
|
return img
|
||||||
|
|
||||||
@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
|
|||||||
for w_idx in w_idx_list:
|
for w_idx in w_idx_list:
|
||||||
if state.interrupted or state.skipped:
|
if state.interrupted or state.skipped:
|
||||||
break
|
break
|
||||||
|
|
||||||
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
|
||||||
out_patch = model(in_patch)
|
out_patch = model(in_patch)
|
||||||
out_patch_mask = torch.ones_like(out_patch)
|
out_patch_mask = torch.ones_like(out_patch)
|
||||||
|
@ -644,7 +644,7 @@ class SwinIR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
@ -805,7 +805,7 @@ class SwinIR(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
x = self.check_image_size(x)
|
x = self.check_image_size(x)
|
||||||
|
|
||||||
self.mean = self.mean.type_as(x)
|
self.mean = self.mean.type_as(x)
|
||||||
x = (x - self.mean) * self.img_range
|
x = (x - self.mean) * self.img_range
|
||||||
|
|
||||||
@ -844,7 +844,7 @@ class SwinIR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
|
@ -74,7 +74,7 @@ class WindowAttention(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
|
||||||
pretrained_window_size=[0, 0]):
|
pretrained_window_size=(0, 0)):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_mask = None
|
attn_mask = None
|
||||||
|
|
||||||
self.register_buffer("attn_mask", attn_mask)
|
self.register_buffer("attn_mask", attn_mask)
|
||||||
|
|
||||||
def calculate_mask(self, x_size):
|
def calculate_mask(self, x_size):
|
||||||
# calculate attention mask for SW-MSA
|
# calculate attention mask for SW-MSA
|
||||||
H, W = x_size
|
H, W = x_size
|
||||||
@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||||
|
|
||||||
return attn_mask
|
return attn_mask
|
||||||
|
|
||||||
def forward(self, x, x_size):
|
def forward(self, x, x_size):
|
||||||
H, W = x_size
|
H, W = x_size
|
||||||
@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module):
|
|||||||
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
|
||||||
else:
|
else:
|
||||||
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
|
||||||
|
|
||||||
# merge windows
|
# merge windows
|
||||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||||
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
|
||||||
@ -369,7 +369,7 @@ class PatchMerging(nn.Module):
|
|||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
|
||||||
flops += H * W * self.dim // 2
|
flops += H * W * self.dim // 2
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
class BasicLayer(nn.Module):
|
class BasicLayer(nn.Module):
|
||||||
""" A basic Swin Transformer layer for one stage.
|
""" A basic Swin Transformer layer for one stage.
|
||||||
@ -447,7 +447,7 @@ class BasicLayer(nn.Module):
|
|||||||
nn.init.constant_(blk.norm1.weight, 0)
|
nn.init.constant_(blk.norm1.weight, 0)
|
||||||
nn.init.constant_(blk.norm2.bias, 0)
|
nn.init.constant_(blk.norm2.bias, 0)
|
||||||
nn.init.constant_(blk.norm2.weight, 0)
|
nn.init.constant_(blk.norm2.weight, 0)
|
||||||
|
|
||||||
class PatchEmbed(nn.Module):
|
class PatchEmbed(nn.Module):
|
||||||
r""" Image to Patch Embedding
|
r""" Image to Patch Embedding
|
||||||
Args:
|
Args:
|
||||||
@ -492,7 +492,7 @@ class PatchEmbed(nn.Module):
|
|||||||
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
||||||
if self.norm is not None:
|
if self.norm is not None:
|
||||||
flops += Ho * Wo * self.embed_dim
|
flops += Ho * Wo * self.embed_dim
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
class RSTB(nn.Module):
|
class RSTB(nn.Module):
|
||||||
"""Residual Swin Transformer Block (RSTB).
|
"""Residual Swin Transformer Block (RSTB).
|
||||||
@ -531,7 +531,7 @@ class RSTB(nn.Module):
|
|||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=mlp_ratio,
|
mlp_ratio=mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop, attn_drop=attn_drop,
|
drop=drop, attn_drop=attn_drop,
|
||||||
drop_path=drop_path,
|
drop_path=drop_path,
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -622,7 +622,7 @@ class Upsample(nn.Sequential):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||||
super(Upsample, self).__init__(*m)
|
super(Upsample, self).__init__(*m)
|
||||||
|
|
||||||
class Upsample_hf(nn.Sequential):
|
class Upsample_hf(nn.Sequential):
|
||||||
"""Upsample module.
|
"""Upsample module.
|
||||||
|
|
||||||
@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential):
|
|||||||
m.append(nn.PixelShuffle(3))
|
m.append(nn.PixelShuffle(3))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
|
||||||
super(Upsample_hf, self).__init__(*m)
|
super(Upsample_hf, self).__init__(*m)
|
||||||
|
|
||||||
|
|
||||||
class UpsampleOneStep(nn.Sequential):
|
class UpsampleOneStep(nn.Sequential):
|
||||||
@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential):
|
|||||||
H, W = self.input_resolution
|
H, W = self.input_resolution
|
||||||
flops = H * W * self.num_feat * 3 * 9
|
flops = H * W * self.num_feat * 3 * 9
|
||||||
return flops
|
return flops
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class Swin2SR(nn.Module):
|
class Swin2SR(nn.Module):
|
||||||
r""" Swin2SR
|
r""" Swin2SR
|
||||||
@ -698,8 +698,8 @@ class Swin2SR(nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
def __init__(self, img_size=64, patch_size=1, in_chans=3,
|
||||||
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
|
embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
|
||||||
window_size=7, mlp_ratio=4., qkv_bias=True,
|
window_size=7, mlp_ratio=4., qkv_bias=True,
|
||||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
|
||||||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
|
||||||
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
|
||||||
@ -764,7 +764,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -776,7 +776,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers.append(layer)
|
self.layers.append(layer)
|
||||||
|
|
||||||
if self.upsampler == 'pixelshuffle_hf':
|
if self.upsampler == 'pixelshuffle_hf':
|
||||||
self.layers_hf = nn.ModuleList()
|
self.layers_hf = nn.ModuleList()
|
||||||
for i_layer in range(self.num_layers):
|
for i_layer in range(self.num_layers):
|
||||||
@ -787,7 +787,7 @@ class Swin2SR(nn.Module):
|
|||||||
num_heads=num_heads[i_layer],
|
num_heads=num_heads[i_layer],
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
mlp_ratio=self.mlp_ratio,
|
mlp_ratio=self.mlp_ratio,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
drop=drop_rate, attn_drop=attn_drop_rate,
|
drop=drop_rate, attn_drop=attn_drop_rate,
|
||||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
|
||||||
norm_layer=norm_layer,
|
norm_layer=norm_layer,
|
||||||
@ -799,7 +799,7 @@ class Swin2SR(nn.Module):
|
|||||||
|
|
||||||
)
|
)
|
||||||
self.layers_hf.append(layer)
|
self.layers_hf.append(layer)
|
||||||
|
|
||||||
self.norm = norm_layer(self.num_features)
|
self.norm = norm_layer(self.num_features)
|
||||||
|
|
||||||
# build the last conv layer in deep feature extraction
|
# build the last conv layer in deep feature extraction
|
||||||
@ -829,10 +829,10 @@ class Swin2SR(nn.Module):
|
|||||||
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
self.conv_after_aux = nn.Sequential(
|
self.conv_after_aux = nn.Sequential(
|
||||||
nn.Conv2d(3, num_feat, 3, 1, 1),
|
nn.Conv2d(3, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.upsample = Upsample(upscale, num_feat)
|
self.upsample = Upsample(upscale, num_feat)
|
||||||
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffle_hf':
|
elif self.upsampler == 'pixelshuffle_hf':
|
||||||
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
@ -846,7 +846,7 @@ class Swin2SR(nn.Module):
|
|||||||
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|
||||||
nn.LeakyReLU(inplace=True))
|
nn.LeakyReLU(inplace=True))
|
||||||
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
|
||||||
|
|
||||||
elif self.upsampler == 'pixelshuffledirect':
|
elif self.upsampler == 'pixelshuffledirect':
|
||||||
# for lightweight SR (to save parameters)
|
# for lightweight SR (to save parameters)
|
||||||
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
|
||||||
@ -905,7 +905,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward_features_hf(self, x):
|
def forward_features_hf(self, x):
|
||||||
x_size = (x.shape[2], x.shape[3])
|
x_size = (x.shape[2], x.shape[3])
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
@ -919,7 +919,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.norm(x) # B L C
|
x = self.norm(x) # B L C
|
||||||
x = self.patch_unembed(x, x_size)
|
x = self.patch_unembed(x, x_size)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
H, W = x.shape[2:]
|
H, W = x.shape[2:]
|
||||||
@ -951,7 +951,7 @@ class Swin2SR(nn.Module):
|
|||||||
x = self.conv_after_body(self.forward_features(x)) + x
|
x = self.conv_after_body(self.forward_features(x)) + x
|
||||||
x_before = self.conv_before_upsample(x)
|
x_before = self.conv_before_upsample(x)
|
||||||
x_out = self.conv_last(self.upsample(x_before))
|
x_out = self.conv_last(self.upsample(x_before))
|
||||||
|
|
||||||
x_hf = self.conv_first_hf(x_before)
|
x_hf = self.conv_first_hf(x_before)
|
||||||
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
|
||||||
x_hf = self.conv_before_upsample_hf(x_hf)
|
x_hf = self.conv_before_upsample_hf(x_hf)
|
||||||
@ -977,15 +977,15 @@ class Swin2SR(nn.Module):
|
|||||||
x_first = self.conv_first(x)
|
x_first = self.conv_first(x)
|
||||||
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
res = self.conv_after_body(self.forward_features(x_first)) + x_first
|
||||||
x = x + self.conv_last(res)
|
x = x + self.conv_last(res)
|
||||||
|
|
||||||
x = x / self.img_range + self.mean
|
x = x / self.img_range + self.mean
|
||||||
if self.upsampler == "pixelshuffle_aux":
|
if self.upsampler == "pixelshuffle_aux":
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
return x[:, :, :H*self.upscale, :W*self.upscale], aux
|
||||||
|
|
||||||
elif self.upsampler == "pixelshuffle_hf":
|
elif self.upsampler == "pixelshuffle_hf":
|
||||||
x_out = x_out / self.img_range + self.mean
|
x_out = x_out / self.img_range + self.mean
|
||||||
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return x[:, :, :H*self.upscale, :W*self.upscale]
|
return x[:, :, :H*self.upscale, :W*self.upscale]
|
||||||
|
|
||||||
@ -994,7 +994,7 @@ class Swin2SR(nn.Module):
|
|||||||
H, W = self.patches_resolution
|
H, W = self.patches_resolution
|
||||||
flops += H * W * 3 * self.embed_dim * 9
|
flops += H * W * 3 * self.embed_dim * 9
|
||||||
flops += self.patch_embed.flops()
|
flops += self.patch_embed.flops()
|
||||||
for i, layer in enumerate(self.layers):
|
for layer in self.layers:
|
||||||
flops += layer.flops()
|
flops += layer.flops()
|
||||||
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
flops += H * W * 3 * self.embed_dim * self.embed_dim
|
||||||
flops += self.upsample.flops()
|
flops += self.upsample.flops()
|
||||||
@ -1014,4 +1014,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
x = torch.randn((1, 3, height, width))
|
x = torch.randn((1, 3, height, width))
|
||||||
x = model(x)
|
x = model(x)
|
||||||
print(x.shape)
|
print(x.shape)
|
||||||
|
@ -1,103 +1,42 @@
|
|||||||
// Stable Diffusion WebUI - Bracket checker
|
// Stable Diffusion WebUI - Bracket checker
|
||||||
// Version 1.0
|
// By Hingashi no Florin/Bwin4L & @akx
|
||||||
// By Hingashi no Florin/Bwin4L
|
|
||||||
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
// Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs.
|
||||||
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
// If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong.
|
||||||
|
|
||||||
function checkBrackets(evt, textArea, counterElt) {
|
function checkBrackets(textArea, counterElt) {
|
||||||
errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n';
|
var counts = {};
|
||||||
errorStringSquare = '[...] - Different number of opening and closing square brackets detected.\n';
|
(textArea.value.match(/[(){}\[\]]/g) || []).forEach(bracket => {
|
||||||
errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n';
|
counts[bracket] = (counts[bracket] || 0) + 1;
|
||||||
|
});
|
||||||
|
var errors = [];
|
||||||
|
|
||||||
openBracketRegExp = /\(/g;
|
function checkPair(open, close, kind) {
|
||||||
closeBracketRegExp = /\)/g;
|
if (counts[open] !== counts[close]) {
|
||||||
|
errors.push(
|
||||||
openSquareBracketRegExp = /\[/g;
|
`${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.`
|
||||||
closeSquareBracketRegExp = /\]/g;
|
);
|
||||||
|
|
||||||
openCurlyBracketRegExp = /\{/g;
|
|
||||||
closeCurlyBracketRegExp = /\}/g;
|
|
||||||
|
|
||||||
totalOpenBracketMatches = 0;
|
|
||||||
totalCloseBracketMatches = 0;
|
|
||||||
totalOpenSquareBracketMatches = 0;
|
|
||||||
totalCloseSquareBracketMatches = 0;
|
|
||||||
totalOpenCurlyBracketMatches = 0;
|
|
||||||
totalCloseCurlyBracketMatches = 0;
|
|
||||||
|
|
||||||
openBracketMatches = textArea.value.match(openBracketRegExp);
|
|
||||||
if(openBracketMatches) {
|
|
||||||
totalOpenBracketMatches = openBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
closeBracketMatches = textArea.value.match(closeBracketRegExp);
|
|
||||||
if(closeBracketMatches) {
|
|
||||||
totalCloseBracketMatches = closeBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
openSquareBracketMatches = textArea.value.match(openSquareBracketRegExp);
|
|
||||||
if(openSquareBracketMatches) {
|
|
||||||
totalOpenSquareBracketMatches = openSquareBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
closeSquareBracketMatches = textArea.value.match(closeSquareBracketRegExp);
|
|
||||||
if(closeSquareBracketMatches) {
|
|
||||||
totalCloseSquareBracketMatches = closeSquareBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
openCurlyBracketMatches = textArea.value.match(openCurlyBracketRegExp);
|
|
||||||
if(openCurlyBracketMatches) {
|
|
||||||
totalOpenCurlyBracketMatches = openCurlyBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
closeCurlyBracketMatches = textArea.value.match(closeCurlyBracketRegExp);
|
|
||||||
if(closeCurlyBracketMatches) {
|
|
||||||
totalCloseCurlyBracketMatches = closeCurlyBracketMatches.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
if(totalOpenBracketMatches != totalCloseBracketMatches) {
|
|
||||||
if(!counterElt.title.includes(errorStringParen)) {
|
|
||||||
counterElt.title += errorStringParen;
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
counterElt.title = counterElt.title.replace(errorStringParen, '');
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(totalOpenSquareBracketMatches != totalCloseSquareBracketMatches) {
|
checkPair('(', ')', 'round brackets');
|
||||||
if(!counterElt.title.includes(errorStringSquare)) {
|
checkPair('[', ']', 'square brackets');
|
||||||
counterElt.title += errorStringSquare;
|
checkPair('{', '}', 'curly brackets');
|
||||||
}
|
counterElt.title = errors.join('\n');
|
||||||
} else {
|
counterElt.classList.toggle('error', errors.length !== 0);
|
||||||
counterElt.title = counterElt.title.replace(errorStringSquare, '');
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if(totalOpenCurlyBracketMatches != totalCloseCurlyBracketMatches) {
|
function setupBracketChecking(id_prompt, id_counter) {
|
||||||
if(!counterElt.title.includes(errorStringCurly)) {
|
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
||||||
counterElt.title += errorStringCurly;
|
var counter = gradioApp().getElementById(id_counter)
|
||||||
}
|
|
||||||
} else {
|
|
||||||
counterElt.title = counterElt.title.replace(errorStringCurly, '');
|
|
||||||
}
|
|
||||||
|
|
||||||
if(counterElt.title != '') {
|
if (textarea && counter) {
|
||||||
counterElt.classList.add('error');
|
textarea.addEventListener("input", () => checkBrackets(textarea, counter));
|
||||||
} else {
|
|
||||||
counterElt.classList.remove('error');
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function setupBracketChecking(id_prompt, id_counter){
|
onUiLoaded(function () {
|
||||||
var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea");
|
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter');
|
||||||
var counter = gradioApp().getElementById(id_counter)
|
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter');
|
||||||
|
setupBracketChecking('img2img_prompt', 'img2img_token_counter');
|
||||||
textarea.addEventListener("input", function(evt){
|
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter');
|
||||||
checkBrackets(evt, textarea, counter)
|
});
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
onUiLoaded(function(){
|
|
||||||
setupBracketChecking('txt2img_prompt', 'txt2img_token_counter')
|
|
||||||
setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter')
|
|
||||||
setupBracketChecking('img2img_prompt', 'img2img_token_counter')
|
|
||||||
setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter')
|
|
||||||
})
|
|
||||||
|
@ -6,7 +6,7 @@
|
|||||||
<ul>
|
<ul>
|
||||||
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
|
||||||
</ul>
|
</ul>
|
||||||
<span style="display:none" class='search_term'>{search_term}</span>
|
<span style="display:none" class='search_term{search_only}'>{search_term}</span>
|
||||||
</div>
|
</div>
|
||||||
<span class='name'>{name}</span>
|
<span class='name'>{name}</span>
|
||||||
<span class='description'>{description}</span>
|
<span class='description'>{description}</span>
|
||||||
|
@ -661,4 +661,30 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
THE SOFTWARE.
|
THE SOFTWARE.
|
||||||
|
</pre>
|
||||||
|
|
||||||
|
<h2><a href="https://github.com/madebyollin/taesd/blob/main/LICENSE">TAESD</a></h2>
|
||||||
|
<small>Tiny AutoEncoder for Stable Diffusion option for live previews</small>
|
||||||
|
<pre>
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Ollin Boer Bohan
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
</pre>
|
</pre>
|
@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
|
|||||||
|
|
||||||
var viewportOffset = targetElement.getBoundingClientRect();
|
var viewportOffset = targetElement.getBoundingClientRect();
|
||||||
|
|
||||||
viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
|
||||||
|
|
||||||
scaledx = targetElement.naturalWidth*viewportscale
|
var scaledx = targetElement.naturalWidth*viewportscale
|
||||||
scaledy = targetElement.naturalHeight*viewportscale
|
var scaledy = targetElement.naturalHeight*viewportscale
|
||||||
|
|
||||||
cleintRectTop = (viewportOffset.top+window.scrollY)
|
var cleintRectTop = (viewportOffset.top+window.scrollY)
|
||||||
cleintRectLeft = (viewportOffset.left+window.scrollX)
|
var cleintRectLeft = (viewportOffset.left+window.scrollX)
|
||||||
cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
|
||||||
cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
|
||||||
|
|
||||||
viewRectTop = cleintRectCentreY-(scaledy/2)
|
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
|
||||||
viewRectLeft = cleintRectCentreX-(scaledx/2)
|
var arscaledx = currentWidth*arscale
|
||||||
arRectWidth = scaledx
|
var arscaledy = currentHeight*arscale
|
||||||
arRectHeight = scaledy
|
|
||||||
|
|
||||||
arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
|
var arRectTop = cleintRectCentreY-(arscaledy/2)
|
||||||
arscaledx = currentWidth*arscale
|
var arRectLeft = cleintRectCentreX-(arscaledx/2)
|
||||||
arscaledy = currentHeight*arscale
|
var arRectWidth = arscaledx
|
||||||
|
var arRectHeight = arscaledy
|
||||||
arRectTop = cleintRectCentreY-(arscaledy/2)
|
|
||||||
arRectLeft = cleintRectCentreX-(arscaledx/2)
|
|
||||||
arRectWidth = arscaledx
|
|
||||||
arRectHeight = arscaledy
|
|
||||||
|
|
||||||
arPreviewRect.style.top = arRectTop+'px';
|
arPreviewRect.style.top = arRectTop+'px';
|
||||||
arPreviewRect.style.left = arRectLeft+'px';
|
arPreviewRect.style.left = arRectLeft+'px';
|
||||||
|
@ -4,7 +4,7 @@ contextMenuInit = function(){
|
|||||||
let menuSpecs = new Map();
|
let menuSpecs = new Map();
|
||||||
|
|
||||||
const uid = function(){
|
const uid = function(){
|
||||||
return Date.now().toString(36) + Math.random().toString(36).substr(2);
|
return Date.now().toString(36) + Math.random().toString(36).substring(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
function showContextMenu(event,element,menuEntries){
|
function showContextMenu(event,element,menuEntries){
|
||||||
@ -16,8 +16,7 @@ contextMenuInit = function(){
|
|||||||
oldMenu.remove()
|
oldMenu.remove()
|
||||||
}
|
}
|
||||||
|
|
||||||
let tabButton = uiCurrentTab
|
let baseStyle = window.getComputedStyle(uiCurrentTab)
|
||||||
let baseStyle = window.getComputedStyle(tabButton)
|
|
||||||
|
|
||||||
const contextMenu = document.createElement('nav')
|
const contextMenu = document.createElement('nav')
|
||||||
contextMenu.id = "context-menu"
|
contextMenu.id = "context-menu"
|
||||||
@ -36,7 +35,7 @@ contextMenuInit = function(){
|
|||||||
menuEntries.forEach(function(entry){
|
menuEntries.forEach(function(entry){
|
||||||
let contextMenuEntry = document.createElement('a')
|
let contextMenuEntry = document.createElement('a')
|
||||||
contextMenuEntry.innerHTML = entry['name']
|
contextMenuEntry.innerHTML = entry['name']
|
||||||
contextMenuEntry.addEventListener("click", function(e) {
|
contextMenuEntry.addEventListener("click", function() {
|
||||||
entry['func']();
|
entry['func']();
|
||||||
})
|
})
|
||||||
contextMenuList.append(contextMenuEntry);
|
contextMenuList.append(contextMenuEntry);
|
||||||
@ -63,7 +62,7 @@ contextMenuInit = function(){
|
|||||||
|
|
||||||
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
|
||||||
|
|
||||||
currentItems = menuSpecs.get(targetElementSelector)
|
var currentItems = menuSpecs.get(targetElementSelector)
|
||||||
|
|
||||||
if(!currentItems){
|
if(!currentItems){
|
||||||
currentItems = []
|
currentItems = []
|
||||||
@ -79,7 +78,7 @@ contextMenuInit = function(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function removeContextMenuOption(uid){
|
function removeContextMenuOption(uid){
|
||||||
menuSpecs.forEach(function(v,k) {
|
menuSpecs.forEach(function(v) {
|
||||||
let index = -1
|
let index = -1
|
||||||
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
|
||||||
if(index>=0){
|
if(index>=0){
|
||||||
@ -93,8 +92,7 @@ contextMenuInit = function(){
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
gradioApp().addEventListener("click", function(e) {
|
gradioApp().addEventListener("click", function(e) {
|
||||||
let source = e.composedPath()[0]
|
if(! e.isTrusted){
|
||||||
if(source.id && source.id.indexOf('check_progress')>-1){
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -112,7 +110,6 @@ contextMenuInit = function(){
|
|||||||
if(e.composedPath()[0].matches(k)){
|
if(e.composedPath()[0].matches(k)){
|
||||||
showContextMenu(e,e.composedPath()[0],v)
|
showContextMenu(e,e.composedPath()[0],v)
|
||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
@ -161,14 +158,6 @@ addContextMenuEventListener = initResponse[2];
|
|||||||
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
appendContextMenuOption('#img2img_interrupt','Cancel generate forever',cancelGenerateForever)
|
||||||
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
appendContextMenuOption('#img2img_generate', 'Cancel generate forever',cancelGenerateForever)
|
||||||
|
|
||||||
appendContextMenuOption('#roll','Roll three',
|
|
||||||
function(){
|
|
||||||
let rollbutton = get_uiCurrentTabContent().querySelector('#roll');
|
|
||||||
setTimeout(function(){rollbutton.click()},100)
|
|
||||||
setTimeout(function(){rollbutton.click()},200)
|
|
||||||
setTimeout(function(){rollbutton.click()},300)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
})();
|
})();
|
||||||
//End example Context Menu Items
|
//End example Context Menu Items
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ function keyupEditAttention(event){
|
|||||||
// Find opening parenthesis around current cursor
|
// Find opening parenthesis around current cursor
|
||||||
const before = text.substring(0, selectionStart);
|
const before = text.substring(0, selectionStart);
|
||||||
let beforeParen = before.lastIndexOf(OPEN);
|
let beforeParen = before.lastIndexOf(OPEN);
|
||||||
if (beforeParen == -1) return false;
|
if (beforeParen == -1) return false;
|
||||||
let beforeParenClose = before.lastIndexOf(CLOSE);
|
let beforeParenClose = before.lastIndexOf(CLOSE);
|
||||||
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
while (beforeParenClose !== -1 && beforeParenClose > beforeParen) {
|
||||||
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
beforeParen = before.lastIndexOf(OPEN, beforeParen - 1);
|
||||||
@ -27,7 +27,7 @@ function keyupEditAttention(event){
|
|||||||
// Find closing parenthesis around current cursor
|
// Find closing parenthesis around current cursor
|
||||||
const after = text.substring(selectionStart);
|
const after = text.substring(selectionStart);
|
||||||
let afterParen = after.indexOf(CLOSE);
|
let afterParen = after.indexOf(CLOSE);
|
||||||
if (afterParen == -1) return false;
|
if (afterParen == -1) return false;
|
||||||
let afterParenOpen = after.indexOf(OPEN);
|
let afterParenOpen = after.indexOf(OPEN);
|
||||||
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
while (afterParenOpen !== -1 && afterParen > afterParenOpen) {
|
||||||
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
afterParen = after.indexOf(CLOSE, afterParen + 1);
|
||||||
@ -43,16 +43,34 @@ function keyupEditAttention(event){
|
|||||||
target.setSelectionRange(selectionStart, selectionEnd);
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function selectCurrentWord(){
|
||||||
|
if (selectionStart !== selectionEnd) return false;
|
||||||
|
const delimiters = opts.keyedit_delimiters + " \r\n\t";
|
||||||
|
|
||||||
|
// seek backward until to find beggining
|
||||||
|
while (!delimiters.includes(text[selectionStart - 1]) && selectionStart > 0) {
|
||||||
|
selectionStart--;
|
||||||
|
}
|
||||||
|
|
||||||
|
// seek forward to find end
|
||||||
|
while (!delimiters.includes(text[selectionEnd]) && selectionEnd < text.length) {
|
||||||
|
selectionEnd++;
|
||||||
|
}
|
||||||
|
|
||||||
// If the user hasn't selected anything, let's select their current parenthesis block
|
target.setSelectionRange(selectionStart, selectionEnd);
|
||||||
if(! selectCurrentParenthesisBlock('<', '>')){
|
return true;
|
||||||
selectCurrentParenthesisBlock('(', ')')
|
}
|
||||||
|
|
||||||
|
// If the user hasn't selected anything, let's select their current parenthesis block or word
|
||||||
|
if (!selectCurrentParenthesisBlock('<', '>') && !selectCurrentParenthesisBlock('(', ')')) {
|
||||||
|
selectCurrentWord();
|
||||||
}
|
}
|
||||||
|
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
|
|
||||||
closeCharacter = ')'
|
var closeCharacter = ')'
|
||||||
delta = opts.keyedit_precision_attention
|
var delta = opts.keyedit_precision_attention
|
||||||
|
|
||||||
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
|
||||||
closeCharacter = '>'
|
closeCharacter = '>'
|
||||||
@ -73,15 +91,21 @@ function keyupEditAttention(event){
|
|||||||
selectionEnd += 1;
|
selectionEnd += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
|
||||||
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
|
||||||
if (isNaN(weight)) return;
|
if (isNaN(weight)) return;
|
||||||
|
|
||||||
weight += isPlus ? delta : -delta;
|
weight += isPlus ? delta : -delta;
|
||||||
weight = parseFloat(weight.toPrecision(12));
|
weight = parseFloat(weight.toPrecision(12));
|
||||||
if(String(weight).length == 1) weight += ".0"
|
if(String(weight).length == 1) weight += ".0"
|
||||||
|
|
||||||
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
if (closeCharacter == ')' && weight == 1) {
|
||||||
|
text = text.slice(0, selectionStart - 1) + text.slice(selectionStart, selectionEnd) + text.slice(selectionEnd + 5);
|
||||||
|
selectionStart--;
|
||||||
|
selectionEnd--;
|
||||||
|
} else {
|
||||||
|
text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1);
|
||||||
|
}
|
||||||
|
|
||||||
target.focus();
|
target.focus();
|
||||||
target.value = text;
|
target.value = text;
|
||||||
@ -93,4 +117,4 @@ function keyupEditAttention(event){
|
|||||||
|
|
||||||
addEventListener('keydown', (event) => {
|
addEventListener('keydown', (event) => {
|
||||||
keyupEditAttention(event);
|
keyupEditAttention(event);
|
||||||
});
|
});
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
|
|
||||||
function extensions_apply(_, _, disable_all){
|
function extensions_apply(_disabled_list, _update_list, disable_all){
|
||||||
var disable = []
|
var disable = []
|
||||||
var update = []
|
var update = []
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if(x.name.startsWith("enable_") && ! x.checked)
|
||||||
disable.push(x.name.substr(7))
|
disable.push(x.name.substring(7))
|
||||||
|
|
||||||
if(x.name.startsWith("update_") && x.checked)
|
if(x.name.startsWith("update_") && x.checked)
|
||||||
update.push(x.name.substr(7))
|
update.push(x.name.substring(7))
|
||||||
})
|
})
|
||||||
|
|
||||||
restart_reload()
|
restart_reload()
|
||||||
@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
|
|||||||
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
|
||||||
}
|
}
|
||||||
|
|
||||||
function extensions_check(_, _){
|
function extensions_check(){
|
||||||
var disable = []
|
var disable = []
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
|
||||||
if(x.name.startsWith("enable_") && ! x.checked)
|
if(x.name.startsWith("enable_") && ! x.checked)
|
||||||
disable.push(x.name.substr(7))
|
disable.push(x.name.substring(7))
|
||||||
})
|
})
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||||
@ -41,9 +41,31 @@ function install_extension_from_index(button, url){
|
|||||||
button.disabled = "disabled"
|
button.disabled = "disabled"
|
||||||
button.value = "Installing..."
|
button.value = "Installing..."
|
||||||
|
|
||||||
textarea = gradioApp().querySelector('#extension_to_install textarea')
|
var textarea = gradioApp().querySelector('#extension_to_install textarea')
|
||||||
textarea.value = url
|
textarea.value = url
|
||||||
updateInput(textarea)
|
updateInput(textarea)
|
||||||
|
|
||||||
gradioApp().querySelector('#install_extension_button').click()
|
gradioApp().querySelector('#install_extension_button').click()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function config_state_confirm_restore(_, config_state_name, config_restore_type) {
|
||||||
|
if (config_state_name == "Current") {
|
||||||
|
return [false, config_state_name, config_restore_type];
|
||||||
|
}
|
||||||
|
let restored = "";
|
||||||
|
if (config_restore_type == "extensions") {
|
||||||
|
restored = "all saved extension versions";
|
||||||
|
} else if (config_restore_type == "webui") {
|
||||||
|
restored = "the webui version";
|
||||||
|
} else {
|
||||||
|
restored = "the webui version and all saved extension versions";
|
||||||
|
}
|
||||||
|
let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".");
|
||||||
|
if (confirmed) {
|
||||||
|
restart_reload();
|
||||||
|
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
|
||||||
|
x.innerHTML = "Loading..."
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return [confirmed, config_state_name, config_restore_type];
|
||||||
|
}
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
function setupExtraNetworksForTab(tabname){
|
function setupExtraNetworksForTab(tabname){
|
||||||
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
|
||||||
|
|
||||||
@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
|
|||||||
tabs.appendChild(search)
|
tabs.appendChild(search)
|
||||||
tabs.appendChild(refresh)
|
tabs.appendChild(refresh)
|
||||||
|
|
||||||
search.addEventListener("input", function(evt){
|
var applyFilter = function(){
|
||||||
searchTerm = search.value.toLowerCase()
|
var searchTerm = search.value.toLowerCase()
|
||||||
|
|
||||||
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
|
||||||
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
var searchOnly = elem.querySelector('.search_only')
|
||||||
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
|
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
|
||||||
|
|
||||||
|
var visible = text.indexOf(searchTerm) != -1
|
||||||
|
|
||||||
|
if(searchOnly && searchTerm.length < 4){
|
||||||
|
visible = false
|
||||||
|
}
|
||||||
|
|
||||||
|
elem.style.display = visible ? "" : "none"
|
||||||
})
|
})
|
||||||
});
|
}
|
||||||
|
|
||||||
|
search.addEventListener("input", applyFilter);
|
||||||
|
applyFilter();
|
||||||
|
|
||||||
|
extraNetworksApplyFilter[tabname] = applyFilter;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function applyExtraNetworkFilter(tabname){
|
||||||
|
setTimeout(extraNetworksApplyFilter[tabname], 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
var extraNetworksApplyFilter = {}
|
||||||
var activePromptTextarea = {};
|
var activePromptTextarea = {};
|
||||||
|
|
||||||
function setupExtraNetworks(){
|
function setupExtraNetworks(){
|
||||||
@ -51,18 +68,27 @@ var re_extranet_g = /\s+<([^:]+:[^:]+):[\d\.]+>/g;
|
|||||||
|
|
||||||
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
function tryToRemoveExtraNetworkFromPrompt(textarea, text){
|
||||||
var m = text.match(re_extranet)
|
var m = text.match(re_extranet)
|
||||||
if(! m) return false
|
|
||||||
|
|
||||||
var partToSearch = m[1]
|
|
||||||
var replaced = false
|
var replaced = false
|
||||||
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
|
var newTextareaText
|
||||||
m = found.match(re_extranet);
|
if(m) {
|
||||||
if(m[1] == partToSearch){
|
var partToSearch = m[1]
|
||||||
replaced = true;
|
newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
|
||||||
return ""
|
m = found.match(re_extranet);
|
||||||
}
|
if(m[1] == partToSearch){
|
||||||
return found;
|
replaced = true;
|
||||||
})
|
return ""
|
||||||
|
}
|
||||||
|
return found;
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
newTextareaText = textarea.value.replaceAll(new RegExp(text, "g"), function(found){
|
||||||
|
if(found == text) {
|
||||||
|
replaced = true;
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return found;
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if(replaced){
|
if(replaced){
|
||||||
textarea.value = newTextareaText
|
textarea.value = newTextareaText
|
||||||
@ -96,9 +122,9 @@ function saveCardPreview(event, tabname, filename){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksSearchButton(tabs_id, event){
|
function extraNetworksSearchButton(tabs_id, event){
|
||||||
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
|
||||||
button = event.target
|
var button = event.target
|
||||||
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
|
||||||
|
|
||||||
searchTextarea.value = text
|
searchTextarea.value = text
|
||||||
updateInput(searchTextarea)
|
updateInput(searchTextarea)
|
||||||
@ -133,7 +159,7 @@ function popup(contents){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksShowMetadata(text){
|
function extraNetworksShowMetadata(text){
|
||||||
elem = document.createElement('pre')
|
var elem = document.createElement('pre')
|
||||||
elem.classList.add('popup-metadata');
|
elem.classList.add('popup-metadata');
|
||||||
elem.textContent = text;
|
elem.textContent = text;
|
||||||
|
|
||||||
@ -165,7 +191,7 @@ function requestGet(url, data, handler, errorHandler){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
function extraNetworksRequestMetadata(event, extraPage, cardName){
|
||||||
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
|
||||||
|
|
||||||
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
|
||||||
if(data && data.metadata){
|
if(data && data.metadata){
|
||||||
|
@ -16,14 +16,14 @@ onUiUpdate(function(){
|
|||||||
|
|
||||||
let modalObserver = new MutationObserver(function(mutations) {
|
let modalObserver = new MutationObserver(function(mutations) {
|
||||||
mutations.forEach(function(mutationRecord) {
|
mutations.forEach(function(mutationRecord) {
|
||||||
let selectedTab = gradioApp().querySelector('#tabs div button.bg-white')?.innerText
|
let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText
|
||||||
if (mutationRecord.target.style.display === 'none' && selectedTab === 'txt2img' || selectedTab === 'img2img')
|
if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img'))
|
||||||
gradioApp().getElementById(selectedTab+"_generation_info_button").click()
|
gradioApp().getElementById(selectedTab+"_generation_info_button")?.click()
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
function attachGalleryListeners(tab_name) {
|
function attachGalleryListeners(tab_name) {
|
||||||
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
|
||||||
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
|
||||||
gallery?.addEventListener('keydown', (e) => {
|
gallery?.addEventListener('keydown', (e) => {
|
||||||
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
|
||||||
|
@ -22,6 +22,7 @@ titles = {
|
|||||||
"\u{1f4cb}": "Apply selected styles to current prompt",
|
"\u{1f4cb}": "Apply selected styles to current prompt",
|
||||||
"\u{1f4d2}": "Paste available values into the field",
|
"\u{1f4d2}": "Paste available values into the field",
|
||||||
"\u{1f3b4}": "Show/hide extra networks",
|
"\u{1f3b4}": "Show/hide extra networks",
|
||||||
|
"\u{1f300}": "Restore progress",
|
||||||
|
|
||||||
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
|
||||||
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
|
||||||
@ -65,8 +66,8 @@ titles = {
|
|||||||
|
|
||||||
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
|
||||||
|
|
||||||
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
"Images filename pattern": "Use tags like [seed] and [date] to define how filenames for images are chosen. Leave empty for default.",
|
||||||
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg],[prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp]; leave empty for default.",
|
"Directory name pattern": "Use tags like [seed] and [date] to define how subdirectories for images and grids are chosen. Leave empty for default.",
|
||||||
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
|
||||||
|
|
||||||
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
|
||||||
@ -85,7 +86,6 @@ titles = {
|
|||||||
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
|
||||||
|
|
||||||
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
|
||||||
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
|
|
||||||
|
|
||||||
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
|
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
|
||||||
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
|
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
|
||||||
@ -111,37 +111,57 @@ titles = {
|
|||||||
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
"Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.",
|
||||||
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
"Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.",
|
||||||
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
"Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.",
|
||||||
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited."
|
"Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first and in order lsited.",
|
||||||
|
"Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function updateTooltipForSpan(span){
|
||||||
|
if (span.title) return; // already has a title
|
||||||
|
|
||||||
onUiUpdate(function(){
|
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
|
||||||
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
|
|
||||||
tooltip = titles[span.textContent];
|
|
||||||
|
|
||||||
if(!tooltip){
|
if(!tooltip){
|
||||||
tooltip = titles[span.value];
|
tooltip = localization[titles[span.value]] || titles[span.value];
|
||||||
}
|
}
|
||||||
|
|
||||||
if(!tooltip){
|
if(!tooltip){
|
||||||
for (const c of span.classList) {
|
for (const c of span.classList) {
|
||||||
if (c in titles) {
|
if (c in titles) {
|
||||||
tooltip = titles[c];
|
tooltip = localization[titles[c]] || titles[c];
|
||||||
break;
|
break;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if(tooltip){
|
if(tooltip){
|
||||||
span.title = tooltip;
|
span.title = tooltip;
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
gradioApp().querySelectorAll('select').forEach(function(select){
|
function updateTooltipForSelect(select){
|
||||||
if (select.onchange != null) return;
|
if (select.onchange != null) return;
|
||||||
|
|
||||||
select.onchange = function(){
|
select.onchange = function(){
|
||||||
select.title = titles[select.value] || "";
|
select.title = localization[titles[select.value]] || titles[select.value] || "";
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
|
observedTooltipElements = {"SPAN": 1, "BUTTON": 1, "SELECT": 1, "P": 1}
|
||||||
|
|
||||||
|
onUiUpdate(function(m){
|
||||||
|
m.forEach(function(record){
|
||||||
|
record.addedNodes.forEach(function(node){
|
||||||
|
if(observedTooltipElements[node.tagName]){
|
||||||
|
updateTooltipForSpan(node)
|
||||||
|
}
|
||||||
|
if(node.tagName == "SELECT"){
|
||||||
|
updateTooltipForSelect(node)
|
||||||
|
}
|
||||||
|
|
||||||
|
if(node.querySelectorAll){
|
||||||
|
node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan)
|
||||||
|
node.querySelectorAll('select').forEach(updateTooltipForSelect)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
@ -1,16 +1,12 @@
|
|||||||
|
|
||||||
function setInactive(elem, inactive){
|
|
||||||
if(inactive){
|
|
||||||
elem.classList.add('inactive')
|
|
||||||
} else{
|
|
||||||
elem.classList.remove('inactive')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
|
||||||
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
function setInactive(elem, inactive){
|
||||||
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
elem.classList.toggle('inactive', !!inactive)
|
||||||
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
}
|
||||||
|
|
||||||
|
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
|
||||||
|
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
|
||||||
|
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
|
||||||
|
|
||||||
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
|
||||||
* @see https://github.com/gradio-app/gradio/issues/1721
|
* @see https://github.com/gradio-app/gradio/issues/1721
|
||||||
*/
|
*/
|
||||||
window.addEventListener( 'resize', () => imageMaskResize());
|
|
||||||
function imageMaskResize() {
|
function imageMaskResize() {
|
||||||
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
|
||||||
if ( ! canvases.length ) {
|
if ( ! canvases.length ) {
|
||||||
canvases_fixed = false;
|
canvases_fixed = false; // TODO: this is unused..?
|
||||||
window.removeEventListener( 'resize', imageMaskResize );
|
window.removeEventListener( 'resize', imageMaskResize );
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -15,7 +14,7 @@ function imageMaskResize() {
|
|||||||
const previewImage = wrapper.previousElementSibling;
|
const previewImage = wrapper.previousElementSibling;
|
||||||
|
|
||||||
if ( ! previewImage.complete ) {
|
if ( ! previewImage.complete ) {
|
||||||
previewImage.addEventListener( 'load', () => imageMaskResize());
|
previewImage.addEventListener( 'load', imageMaskResize);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,7 +23,6 @@ function imageMaskResize() {
|
|||||||
const nw = previewImage.naturalWidth;
|
const nw = previewImage.naturalWidth;
|
||||||
const nh = previewImage.naturalHeight;
|
const nh = previewImage.naturalHeight;
|
||||||
const portrait = nh > nw;
|
const portrait = nh > nw;
|
||||||
const factor = portrait;
|
|
||||||
|
|
||||||
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
|
||||||
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
|
||||||
@ -40,6 +38,7 @@ function imageMaskResize() {
|
|||||||
c.style.maxHeight = '100%';
|
c.style.maxHeight = '100%';
|
||||||
c.style.objectFit = 'contain';
|
c.style.objectFit = 'contain';
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(() => imageMaskResize());
|
onUiUpdate(imageMaskResize);
|
||||||
|
window.addEventListener( 'resize', imageMaskResize);
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
window.onload = (function(){
|
window.onload = (function(){
|
||||||
window.addEventListener('drop', e => {
|
window.addEventListener('drop', e => {
|
||||||
const target = e.composedPath()[0];
|
const target = e.composedPath()[0];
|
||||||
const idx = selected_gallery_index();
|
|
||||||
if (target.placeholder.indexOf("Prompt") == -1) return;
|
if (target.placeholder.indexOf("Prompt") == -1) return;
|
||||||
|
|
||||||
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
|
||||||
|
@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if (result != -1) {
|
if (result != -1) {
|
||||||
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
|
||||||
nextButton.click()
|
nextButton.click()
|
||||||
const modalImage = gradioApp().getElementById("modalImage");
|
const modalImage = gradioApp().getElementById("modalImage");
|
||||||
const modal = gradioApp().getElementById("lightboxModal");
|
const modal = gradioApp().getElementById("lightboxModal");
|
||||||
@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomSet(modalImage, enable) {
|
function modalZoomSet(modalImage, enable) {
|
||||||
if (enable) {
|
if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
|
||||||
modalImage.classList.add('modalImageFullscreen');
|
|
||||||
} else {
|
|
||||||
modalImage.classList.remove('modalImageFullscreen');
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function modalZoomToggle(event) {
|
function modalZoomToggle(event) {
|
||||||
modalImage = gradioApp().getElementById("modalImage");
|
var modalImage = gradioApp().getElementById("modalImage");
|
||||||
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
}
|
}
|
||||||
@ -179,7 +175,7 @@ function galleryImageHandler(e) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function() {
|
onUiUpdate(function() {
|
||||||
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
|
||||||
if (fullImg_preview != null) {
|
if (fullImg_preview != null) {
|
||||||
fullImg_preview.forEach(setupImageForLightbox);
|
fullImg_preview.forEach(setupImageForLightbox);
|
||||||
}
|
}
|
||||||
@ -251,8 +247,11 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
|
|
||||||
modal.appendChild(modalNext)
|
modal.appendChild(modalNext)
|
||||||
|
|
||||||
gradioApp().appendChild(modal)
|
try {
|
||||||
|
gradioApp().appendChild(modal);
|
||||||
|
} catch (e) {
|
||||||
|
gradioApp().body.appendChild(modal);
|
||||||
|
}
|
||||||
|
|
||||||
document.body.appendChild(modal);
|
document.body.appendChild(modal);
|
||||||
|
|
||||||
|
57
javascript/imageviewerGamepad.js
Normal file
57
javascript/imageviewerGamepad.js
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
window.addEventListener('gamepadconnected', (e) => {
|
||||||
|
const index = e.gamepad.index;
|
||||||
|
let isWaiting = false;
|
||||||
|
setInterval(async () => {
|
||||||
|
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
|
||||||
|
const gamepad = navigator.getGamepads()[index];
|
||||||
|
const xValue = gamepad.axes[0];
|
||||||
|
if (xValue <= -0.3) {
|
||||||
|
modalPrevImage(e);
|
||||||
|
isWaiting = true;
|
||||||
|
} else if (xValue >= 0.3) {
|
||||||
|
modalNextImage(e);
|
||||||
|
isWaiting = true;
|
||||||
|
}
|
||||||
|
if (isWaiting) {
|
||||||
|
await sleepUntil(() => {
|
||||||
|
const xValue = navigator.getGamepads()[index].axes[0]
|
||||||
|
if (xValue < 0.3 && xValue > -0.3) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}, opts.js_modal_lightbox_gamepad_repeat);
|
||||||
|
isWaiting = false;
|
||||||
|
}
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
|
||||||
|
/*
|
||||||
|
Primarily for vr controller type pointer devices.
|
||||||
|
I use the wheel event because there's currently no way to do it properly with web xr.
|
||||||
|
*/
|
||||||
|
let isScrolling = false;
|
||||||
|
window.addEventListener('wheel', (e) => {
|
||||||
|
if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
|
||||||
|
isScrolling = true;
|
||||||
|
|
||||||
|
if (e.deltaX <= -0.6) {
|
||||||
|
modalPrevImage(e);
|
||||||
|
} else if (e.deltaX >= 0.6) {
|
||||||
|
modalNextImage(e);
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
isScrolling = false;
|
||||||
|
}, opts.js_modal_lightbox_gamepad_repeat);
|
||||||
|
});
|
||||||
|
|
||||||
|
function sleepUntil(f, timeout) {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
const timeStart = new Date();
|
||||||
|
const wait = setInterval(function() {
|
||||||
|
if (f() || new Date() - timeStart > timeout) {
|
||||||
|
clearInterval(wait);
|
||||||
|
resolve();
|
||||||
|
}
|
||||||
|
}, 20);
|
||||||
|
});
|
||||||
|
}
|
@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
|
|||||||
original_lines = {}
|
original_lines = {}
|
||||||
translated_lines = {}
|
translated_lines = {}
|
||||||
|
|
||||||
|
function hasLocalization() {
|
||||||
|
return window.localization && Object.keys(window.localization).length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
function textNodesUnder(el){
|
function textNodesUnder(el){
|
||||||
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
|
||||||
while(n=walk.nextNode()) a.push(n);
|
while(n=walk.nextNode()) a.push(n);
|
||||||
@ -35,11 +39,11 @@ function canBeTranslated(node, text){
|
|||||||
if(! text) return false;
|
if(! text) return false;
|
||||||
if(! node.parentElement) return false;
|
if(! node.parentElement) return false;
|
||||||
|
|
||||||
parentType = node.parentElement.nodeName
|
var parentType = node.parentElement.nodeName
|
||||||
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
|
||||||
|
|
||||||
if (parentType=='OPTION' || parentType=='SPAN'){
|
if (parentType=='OPTION' || parentType=='SPAN'){
|
||||||
pnode = node
|
var pnode = node
|
||||||
for(var level=0; level<4; level++){
|
for(var level=0; level<4; level++){
|
||||||
pnode = pnode.parentElement
|
pnode = pnode.parentElement
|
||||||
if(! pnode) break;
|
if(! pnode) break;
|
||||||
@ -69,7 +73,7 @@ function getTranslation(text){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function processTextNode(node){
|
function processTextNode(node){
|
||||||
text = node.textContent.trim()
|
var text = node.textContent.trim()
|
||||||
|
|
||||||
if(! canBeTranslated(node, text)) return
|
if(! canBeTranslated(node, text)) return
|
||||||
|
|
||||||
@ -105,30 +109,52 @@ function processNode(node){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function dumpTranslations(){
|
function dumpTranslations(){
|
||||||
dumped = {}
|
if(!hasLocalization()) {
|
||||||
|
// If we don't have any localization,
|
||||||
|
// we will not have traversed the app to find
|
||||||
|
// original_lines, so do that now.
|
||||||
|
processNode(gradioApp());
|
||||||
|
}
|
||||||
|
var dumped = {}
|
||||||
if (localization.rtl) {
|
if (localization.rtl) {
|
||||||
dumped.rtl = true
|
dumped.rtl = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
Object.keys(original_lines).forEach(function(text){
|
for (const text in original_lines) {
|
||||||
if(dumped[text] !== undefined) return
|
if(dumped[text] !== undefined) continue;
|
||||||
|
dumped[text] = localization[text] || text;
|
||||||
|
}
|
||||||
|
|
||||||
dumped[text] = localization[text] || text
|
return dumped;
|
||||||
})
|
|
||||||
|
|
||||||
return dumped
|
|
||||||
}
|
}
|
||||||
|
|
||||||
onUiUpdate(function(m){
|
function download_localization() {
|
||||||
m.forEach(function(mutation){
|
var text = JSON.stringify(dumpTranslations(), null, 4)
|
||||||
mutation.addedNodes.forEach(function(node){
|
|
||||||
processNode(node)
|
|
||||||
})
|
|
||||||
});
|
|
||||||
})
|
|
||||||
|
|
||||||
|
var element = document.createElement('a');
|
||||||
|
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
||||||
|
element.setAttribute('download', "localization.json");
|
||||||
|
element.style.display = 'none';
|
||||||
|
document.body.appendChild(element);
|
||||||
|
|
||||||
|
element.click();
|
||||||
|
|
||||||
|
document.body.removeChild(element);
|
||||||
|
}
|
||||||
|
|
||||||
|
document.addEventListener("DOMContentLoaded", function () {
|
||||||
|
if (!hasLocalization()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
onUiUpdate(function (m) {
|
||||||
|
m.forEach(function (mutation) {
|
||||||
|
mutation.addedNodes.forEach(function (node) {
|
||||||
|
processNode(node)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
})
|
||||||
|
|
||||||
document.addEventListener("DOMContentLoaded", function() {
|
|
||||||
processNode(gradioApp())
|
processNode(gradioApp())
|
||||||
|
|
||||||
if (localization.rtl) { // if the language is from right to left,
|
if (localization.rtl) { // if the language is from right to left,
|
||||||
@ -149,17 +175,3 @@ document.addEventListener("DOMContentLoaded", function() {
|
|||||||
})).observe(gradioApp(), { childList: true });
|
})).observe(gradioApp(), { childList: true });
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
function download_localization() {
|
|
||||||
text = JSON.stringify(dumpTranslations(), null, 4)
|
|
||||||
|
|
||||||
var element = document.createElement('a');
|
|
||||||
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
|
|
||||||
element.setAttribute('download', "localization.json");
|
|
||||||
element.style.display = 'none';
|
|
||||||
document.body.appendChild(element);
|
|
||||||
|
|
||||||
element.click();
|
|
||||||
|
|
||||||
document.body.removeChild(element);
|
|
||||||
}
|
|
||||||
|
@ -2,15 +2,15 @@
|
|||||||
|
|
||||||
let lastHeadImg = null;
|
let lastHeadImg = null;
|
||||||
|
|
||||||
notificationButton = null
|
let notificationButton = null;
|
||||||
|
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(notificationButton == null){
|
if(notificationButton == null){
|
||||||
notificationButton = gradioApp().getElementById('request_notifications')
|
notificationButton = gradioApp().getElementById('request_notifications')
|
||||||
|
|
||||||
if(notificationButton != null){
|
if(notificationButton != null){
|
||||||
notificationButton.addEventListener('click', function (evt) {
|
notificationButton.addEventListener('click', () => {
|
||||||
Notification.requestPermission();
|
void Notification.requestPermission();
|
||||||
},true);
|
},true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,16 +1,15 @@
|
|||||||
// code related to showing and updating progressbar shown as the image is being made
|
// code related to showing and updating progressbar shown as the image is being made
|
||||||
|
|
||||||
function rememberGallerySelection(id_gallery){
|
function rememberGallerySelection(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function getGallerySelectedIndex(id_gallery){
|
function getGallerySelectedIndex(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function request(url, data, handler, errorHandler){
|
function request(url, data, handler, errorHandler){
|
||||||
var xhr = new XMLHttpRequest();
|
var xhr = new XMLHttpRequest();
|
||||||
var url = url;
|
|
||||||
xhr.open("POST", url, true);
|
xhr.open("POST", url, true);
|
||||||
xhr.setRequestHeader("Content-Type", "application/json");
|
xhr.setRequestHeader("Content-Type", "application/json");
|
||||||
xhr.onreadystatechange = function () {
|
xhr.onreadystatechange = function () {
|
||||||
@ -66,7 +65,7 @@ function randomId(){
|
|||||||
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and
|
||||||
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd.
|
||||||
// calls onProgress every time there is a progress update
|
// calls onProgress every time there is a progress update
|
||||||
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){
|
function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout=40){
|
||||||
var dateStart = new Date()
|
var dateStart = new Date()
|
||||||
var wasEverActive = false
|
var wasEverActive = false
|
||||||
var parentProgressbar = progressbarContainer.parentNode
|
var parentProgressbar = progressbarContainer.parentNode
|
||||||
@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
divProgress.style.width = rect.width + "px";
|
divProgress.style.width = rect.width + "px";
|
||||||
}
|
}
|
||||||
|
|
||||||
progressText = ""
|
let progressText = ""
|
||||||
|
|
||||||
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
|
||||||
divInner.style.background = res.progress ? "" : "transparent"
|
divInner.style.background = res.progress ? "" : "transparent"
|
||||||
@ -138,7 +137,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if(elapsedFromStart > 5 && !res.queued && !res.active){
|
if(elapsedFromStart > inactivityTimeout && !res.queued && !res.active){
|
||||||
removeProgressBar()
|
removeProgressBar()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
108
javascript/ui.js
108
javascript/ui.js
@ -1,7 +1,7 @@
|
|||||||
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
|
||||||
|
|
||||||
function set_theme(theme){
|
function set_theme(theme){
|
||||||
gradioURL = window.location.href
|
var gradioURL = window.location.href
|
||||||
if (!gradioURL.includes('?__theme=')) {
|
if (!gradioURL.includes('?__theme=')) {
|
||||||
window.location.replace(gradioURL + '?__theme=' + theme);
|
window.location.replace(gradioURL + '?__theme=' + theme);
|
||||||
}
|
}
|
||||||
@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
|
|||||||
return [gallery[0]];
|
return [gallery[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
index = selected_gallery_index()
|
var index = selected_gallery_index()
|
||||||
|
|
||||||
if (index < 0 || index >= gallery.length){
|
if (index < 0 || index >= gallery.length){
|
||||||
// Use the first image in the gallery as the default
|
// Use the first image in the gallery as the default
|
||||||
@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
|
|||||||
}
|
}
|
||||||
|
|
||||||
function args_to_array(args){
|
function args_to_array(args){
|
||||||
res = []
|
var res = []
|
||||||
for(var i=0;i<args.length;i++){
|
for(var i=0;i<args.length;i++){
|
||||||
res.push(args[i])
|
res.push(args[i])
|
||||||
}
|
}
|
||||||
@ -138,7 +138,7 @@ function get_img2img_tab_index() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function create_submit_args(args){
|
function create_submit_args(args){
|
||||||
res = []
|
var res = []
|
||||||
for(var i=0;i<args.length;i++){
|
for(var i=0;i<args.length;i++){
|
||||||
res.push(args[i])
|
res.push(args[i])
|
||||||
}
|
}
|
||||||
@ -159,14 +159,24 @@ function showSubmitButtons(tabname, show){
|
|||||||
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function showRestoreProgressButton(tabname, show){
|
||||||
|
var button = gradioApp().getElementById(tabname + "_restore_progress")
|
||||||
|
if(! button) return
|
||||||
|
|
||||||
|
button.style.display = show ? "flex" : "none"
|
||||||
|
}
|
||||||
|
|
||||||
function submit(){
|
function submit(){
|
||||||
rememberGallerySelection('txt2img_gallery')
|
rememberGallerySelection('txt2img_gallery')
|
||||||
showSubmitButtons('txt2img', false)
|
showSubmitButtons('txt2img', false)
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
|
localStorage.setItem("txt2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
||||||
showSubmitButtons('txt2img', true)
|
showSubmitButtons('txt2img', true)
|
||||||
|
localStorage.removeItem("txt2img_task_id")
|
||||||
|
showRestoreProgressButton('txt2img', false)
|
||||||
})
|
})
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments)
|
||||||
@ -181,8 +191,12 @@ function submit_img2img(){
|
|||||||
showSubmitButtons('img2img', false)
|
showSubmitButtons('img2img', false)
|
||||||
|
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
|
localStorage.setItem("img2img_task_id", id);
|
||||||
|
|
||||||
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||||
showSubmitButtons('img2img', true)
|
showSubmitButtons('img2img', true)
|
||||||
|
localStorage.removeItem("img2img_task_id")
|
||||||
|
showRestoreProgressButton('img2img', false)
|
||||||
})
|
})
|
||||||
|
|
||||||
var res = create_submit_args(arguments)
|
var res = create_submit_args(arguments)
|
||||||
@ -193,6 +207,42 @@ function submit_img2img(){
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function restoreProgressTxt2img(){
|
||||||
|
showRestoreProgressButton("txt2img", false)
|
||||||
|
var id = localStorage.getItem("txt2img_task_id")
|
||||||
|
|
||||||
|
id = localStorage.getItem("txt2img_task_id")
|
||||||
|
|
||||||
|
if(id) {
|
||||||
|
requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){
|
||||||
|
showSubmitButtons('txt2img', true)
|
||||||
|
}, null, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
function restoreProgressImg2img(){
|
||||||
|
showRestoreProgressButton("img2img", false)
|
||||||
|
|
||||||
|
var id = localStorage.getItem("img2img_task_id")
|
||||||
|
|
||||||
|
if(id) {
|
||||||
|
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
|
||||||
|
showSubmitButtons('img2img', true)
|
||||||
|
}, null, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
onUiLoaded(function () {
|
||||||
|
showRestoreProgressButton('txt2img', localStorage.getItem("txt2img_task_id"))
|
||||||
|
showRestoreProgressButton('img2img', localStorage.getItem("img2img_task_id"))
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
function modelmerger(){
|
function modelmerger(){
|
||||||
var id = randomId()
|
var id = randomId()
|
||||||
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})
|
||||||
@ -204,7 +254,7 @@ function modelmerger(){
|
|||||||
|
|
||||||
|
|
||||||
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
|
||||||
name_ = prompt('Style name:')
|
var name_ = prompt('Style name:')
|
||||||
return [name_, prompt_text, negative_prompt_text]
|
return [name_, prompt_text, negative_prompt_text]
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -239,11 +289,11 @@ function recalculate_prompts_img2img(){
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
opts = {}
|
var opts = {}
|
||||||
onUiUpdate(function(){
|
onUiUpdate(function(){
|
||||||
if(Object.keys(opts).length != 0) return;
|
if(Object.keys(opts).length != 0) return;
|
||||||
|
|
||||||
json_elem = gradioApp().getElementById('settings_json')
|
var json_elem = gradioApp().getElementById('settings_json')
|
||||||
if(json_elem == null) return;
|
if(json_elem == null) return;
|
||||||
|
|
||||||
var textarea = json_elem.querySelector('textarea')
|
var textarea = json_elem.querySelector('textarea')
|
||||||
@ -292,12 +342,15 @@ onUiUpdate(function(){
|
|||||||
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
|
||||||
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
|
||||||
|
|
||||||
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
|
||||||
settings_tabs = gradioApp().querySelector('#settings div')
|
var settings_tabs = gradioApp().querySelector('#settings div')
|
||||||
if(show_all_pages && settings_tabs){
|
if(show_all_pages && settings_tabs){
|
||||||
settings_tabs.appendChild(show_all_pages)
|
settings_tabs.appendChild(show_all_pages)
|
||||||
show_all_pages.onclick = function(){
|
show_all_pages.onclick = function(){
|
||||||
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
|
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
|
||||||
|
if(elem.id == "settings_tab_licenses")
|
||||||
|
return;
|
||||||
|
|
||||||
elem.style.display = "block";
|
elem.style.display = "block";
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -305,9 +358,9 @@ onUiUpdate(function(){
|
|||||||
})
|
})
|
||||||
|
|
||||||
onOptionsChanged(function(){
|
onOptionsChanged(function(){
|
||||||
elem = gradioApp().getElementById('sd_checkpoint_hash')
|
var elem = gradioApp().getElementById('sd_checkpoint_hash')
|
||||||
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
|
||||||
shorthash = sd_checkpoint_hash.substr(0,10)
|
var shorthash = sd_checkpoint_hash.substring(0,10)
|
||||||
|
|
||||||
if(elem && elem.textContent != shorthash){
|
if(elem && elem.textContent != shorthash){
|
||||||
elem.textContent = shorthash
|
elem.textContent = shorthash
|
||||||
@ -342,7 +395,16 @@ function update_token_counter(button_id) {
|
|||||||
|
|
||||||
function restart_reload(){
|
function restart_reload(){
|
||||||
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
|
||||||
setTimeout(function(){location.reload()},2000)
|
|
||||||
|
var requestPing = function(){
|
||||||
|
requestGet("./internal/ping", {}, function(data){
|
||||||
|
location.reload();
|
||||||
|
}, function(){
|
||||||
|
setTimeout(requestPing, 500);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
setTimeout(requestPing, 2000);
|
||||||
|
|
||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
@ -362,6 +424,23 @@ function selectCheckpoint(name){
|
|||||||
gradioApp().getElementById('change_checkpoint').click()
|
gradioApp().getElementById('change_checkpoint').click()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function currentImg2imgSourceResolution(_, _, scaleBy){
|
||||||
|
var img = gradioApp().querySelector('#mode_img2img > div[style="display: block;"] img')
|
||||||
|
return img ? [img.naturalWidth, img.naturalHeight, scaleBy] : [0, 0, scaleBy]
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateImg2imgResizeToTextAfterChangingImage(){
|
||||||
|
// At the time this is called from gradio, the image has no yet been replaced.
|
||||||
|
// There may be a better solution, but this is simple and straightforward so I'm going with it.
|
||||||
|
|
||||||
|
setTimeout(function() {
|
||||||
|
gradioApp().getElementById('img2img_update_resize_to').click()
|
||||||
|
}, 500);
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
function setRandomSeed(target_interface) {
|
function setRandomSeed(target_interface) {
|
||||||
let seed = gradioApp().querySelector(`#${target_interface}_seed input`);
|
let seed = gradioApp().querySelector(`#${target_interface}_seed input`);
|
||||||
if (!seed) {
|
if (!seed) {
|
||||||
@ -409,3 +488,4 @@ function switchWidthHeightImg2Img() {
|
|||||||
height.dispatchEvent(new Event("input"));
|
height.dispatchEvent(new Event("input"));
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
62
javascript/ui_settings_hints.js
Normal file
62
javascript/ui_settings_hints.js
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
// various hints and extra info for the settings tab
|
||||||
|
|
||||||
|
settingsHintsSetup = false
|
||||||
|
|
||||||
|
onOptionsChanged(function(){
|
||||||
|
if(settingsHintsSetup) return
|
||||||
|
settingsHintsSetup = true
|
||||||
|
|
||||||
|
gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div){
|
||||||
|
var name = div.id.substr(8)
|
||||||
|
var commentBefore = opts._comments_before[name]
|
||||||
|
var commentAfter = opts._comments_after[name]
|
||||||
|
|
||||||
|
if(! commentBefore && !commentAfter) return
|
||||||
|
|
||||||
|
var span = null
|
||||||
|
if(div.classList.contains('gradio-checkbox')) span = div.querySelector('label span')
|
||||||
|
else if(div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild
|
||||||
|
else if(div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild
|
||||||
|
else span = div.querySelector('label span').firstChild
|
||||||
|
|
||||||
|
if(!span) return
|
||||||
|
|
||||||
|
if(commentBefore){
|
||||||
|
var comment = document.createElement('DIV')
|
||||||
|
comment.className = 'settings-comment'
|
||||||
|
comment.innerHTML = commentBefore
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
|
||||||
|
span.parentElement.insertBefore(comment, span)
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span)
|
||||||
|
}
|
||||||
|
if(commentAfter){
|
||||||
|
var comment = document.createElement('DIV')
|
||||||
|
comment.className = 'settings-comment'
|
||||||
|
comment.innerHTML = commentAfter
|
||||||
|
span.parentElement.insertBefore(comment, span.nextSibling)
|
||||||
|
span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
function settingsHintsShowQuicksettings(){
|
||||||
|
requestGet("./internal/quicksettings-hint", {}, function(data){
|
||||||
|
var table = document.createElement('table')
|
||||||
|
table.className = 'settings-value-table'
|
||||||
|
|
||||||
|
data.forEach(function(obj){
|
||||||
|
var tr = document.createElement('tr')
|
||||||
|
var td = document.createElement('td')
|
||||||
|
td.textContent = obj.name
|
||||||
|
tr.appendChild(td)
|
||||||
|
|
||||||
|
var td = document.createElement('td')
|
||||||
|
td.textContent = obj.label
|
||||||
|
tr.appendChild(td)
|
||||||
|
|
||||||
|
table.appendChild(tr)
|
||||||
|
})
|
||||||
|
|
||||||
|
popup(table);
|
||||||
|
})
|
||||||
|
}
|
120
launch.py
120
launch.py
@ -3,25 +3,23 @@ import subprocess
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import shlex
|
|
||||||
import platform
|
import platform
|
||||||
import json
|
import json
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
from modules import cmd_args
|
from modules import cmd_args
|
||||||
from modules.paths_internal import script_path, extensions_dir
|
from modules.paths_internal import script_path, extensions_dir
|
||||||
|
|
||||||
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
|
||||||
sys.argv += shlex.split(commandline_args)
|
|
||||||
|
|
||||||
args, _ = cmd_args.parser.parse_known_args()
|
args, _ = cmd_args.parser.parse_known_args()
|
||||||
|
|
||||||
python = sys.executable
|
python = sys.executable
|
||||||
git = os.environ.get('GIT', "git")
|
git = os.environ.get('GIT', "git")
|
||||||
index_url = os.environ.get('INDEX_URL', "")
|
index_url = os.environ.get('INDEX_URL', "")
|
||||||
stored_commit_hash = None
|
|
||||||
skip_install = False
|
|
||||||
dir_repos = "repositories"
|
dir_repos = "repositories"
|
||||||
|
|
||||||
|
# Whether to default to printing command output
|
||||||
|
default_command_live = (os.environ.get('WEBUI_LAUNCH_LIVE_OUTPUT') == "1")
|
||||||
|
|
||||||
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
||||||
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
||||||
|
|
||||||
@ -49,7 +47,7 @@ or any other error regarding unsuccessful package (library) installation,
|
|||||||
please downgrade (or upgrade) to the latest version of 3.10 Python
|
please downgrade (or upgrade) to the latest version of 3.10 Python
|
||||||
and delete current Python and "venv" folder in WebUI's directory.
|
and delete current Python and "venv" folder in WebUI's directory.
|
||||||
|
|
||||||
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
|
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3106/
|
||||||
|
|
||||||
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
||||||
|
|
||||||
@ -57,51 +55,52 @@ Use --skip-python-version-check to suppress this warning.
|
|||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
def commit_hash():
|
def commit_hash():
|
||||||
global stored_commit_hash
|
|
||||||
|
|
||||||
if stored_commit_hash is not None:
|
|
||||||
return stored_commit_hash
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
|
return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
|
||||||
except Exception:
|
except Exception:
|
||||||
stored_commit_hash = "<none>"
|
return "<none>"
|
||||||
|
|
||||||
return stored_commit_hash
|
|
||||||
|
|
||||||
|
|
||||||
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
@lru_cache()
|
||||||
|
def git_tag():
|
||||||
|
try:
|
||||||
|
return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
|
||||||
|
except Exception:
|
||||||
|
return "<none>"
|
||||||
|
|
||||||
|
|
||||||
|
def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str:
|
||||||
if desc is not None:
|
if desc is not None:
|
||||||
print(desc)
|
print(desc)
|
||||||
|
|
||||||
if live:
|
run_kwargs = {
|
||||||
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
"args": command,
|
||||||
if result.returncode != 0:
|
"shell": True,
|
||||||
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
"env": os.environ if custom_env is None else custom_env,
|
||||||
Command: {command}
|
"encoding": 'utf8',
|
||||||
Error code: {result.returncode}""")
|
"errors": 'ignore',
|
||||||
|
}
|
||||||
|
|
||||||
return ""
|
if not live:
|
||||||
|
run_kwargs["stdout"] = run_kwargs["stderr"] = subprocess.PIPE
|
||||||
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
result = subprocess.run(**run_kwargs)
|
||||||
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
|
error_bits = [
|
||||||
|
f"{errdesc or 'Error running command'}.",
|
||||||
|
f"Command: {command}",
|
||||||
|
f"Error code: {result.returncode}",
|
||||||
|
]
|
||||||
|
if result.stdout:
|
||||||
|
error_bits.append(f"stdout: {result.stdout}")
|
||||||
|
if result.stderr:
|
||||||
|
error_bits.append(f"stderr: {result.stderr}")
|
||||||
|
raise RuntimeError("\n".join(error_bits))
|
||||||
|
|
||||||
message = f"""{errdesc or 'Error running command'}.
|
return (result.stdout or "")
|
||||||
Command: {command}
|
|
||||||
Error code: {result.returncode}
|
|
||||||
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
|
|
||||||
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
|
|
||||||
"""
|
|
||||||
raise RuntimeError(message)
|
|
||||||
|
|
||||||
return result.stdout.decode(encoding="utf8", errors="ignore")
|
|
||||||
|
|
||||||
|
|
||||||
def check_run(command):
|
|
||||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
|
||||||
return result.returncode == 0
|
|
||||||
|
|
||||||
|
|
||||||
def is_installed(package):
|
def is_installed(package):
|
||||||
@ -117,20 +116,17 @@ def repo_dir(name):
|
|||||||
return os.path.join(script_path, dir_repos, name)
|
return os.path.join(script_path, dir_repos, name)
|
||||||
|
|
||||||
|
|
||||||
def run_python(code, desc=None, errdesc=None):
|
def run_pip(command, desc=None, live=default_command_live):
|
||||||
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
if args.skip_install:
|
||||||
|
|
||||||
|
|
||||||
def run_pip(args, desc=None):
|
|
||||||
if skip_install:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
||||||
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
return run(f'"{python}" -m pip {command} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)
|
||||||
|
|
||||||
|
|
||||||
def check_run_python(code):
|
def check_run_python(code: str) -> bool:
|
||||||
return check_run(f'"{python}" -c "{code}"')
|
result = subprocess.run([python, "-c", code], capture_output=True, shell=False)
|
||||||
|
return result.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
def git_clone(url, dir, name, commithash=None):
|
def git_clone(url, dir, name, commithash=None):
|
||||||
@ -223,15 +219,14 @@ def run_extensions_installers(settings_file):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_environment():
|
def prepare_environment():
|
||||||
global skip_install
|
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
|
||||||
|
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
|
||||||
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
|
|
||||||
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
|
||||||
|
|
||||||
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
|
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
|
||||||
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
|
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip")
|
||||||
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
|
clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
|
||||||
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b")
|
openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
|
||||||
|
|
||||||
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
|
||||||
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git")
|
||||||
@ -249,15 +244,20 @@ def prepare_environment():
|
|||||||
check_python_version()
|
check_python_version()
|
||||||
|
|
||||||
commit = commit_hash()
|
commit = commit_hash()
|
||||||
|
tag = git_tag()
|
||||||
|
|
||||||
print(f"Python {sys.version}")
|
print(f"Python {sys.version}")
|
||||||
|
print(f"Version: {tag}")
|
||||||
print(f"Commit hash: {commit}")
|
print(f"Commit hash: {commit}")
|
||||||
|
|
||||||
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
|
||||||
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
||||||
|
|
||||||
if not args.skip_torch_cuda_test:
|
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
|
||||||
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")
|
raise RuntimeError(
|
||||||
|
'Torch is not able to use GPU; '
|
||||||
|
'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
|
||||||
|
)
|
||||||
|
|
||||||
if not is_installed("gfpgan"):
|
if not is_installed("gfpgan"):
|
||||||
run_pip(f"install {gfpgan_package}", "gfpgan")
|
run_pip(f"install {gfpgan_package}", "gfpgan")
|
||||||
@ -271,7 +271,7 @@ def prepare_environment():
|
|||||||
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
if platform.python_version().startswith("3.10"):
|
if platform.python_version().startswith("3.10"):
|
||||||
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
|
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True)
|
||||||
else:
|
else:
|
||||||
print("Installation of xformers is not supported in this version of Python.")
|
print("Installation of xformers is not supported in this version of Python.")
|
||||||
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness")
|
||||||
@ -296,7 +296,7 @@ def prepare_environment():
|
|||||||
|
|
||||||
if not os.path.isfile(requirements_file):
|
if not os.path.isfile(requirements_file):
|
||||||
requirements_file = os.path.join(script_path, requirements_file)
|
requirements_file = os.path.join(script_path, requirements_file)
|
||||||
run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI")
|
run_pip(f"install -r \"{requirements_file}\"", "requirements")
|
||||||
|
|
||||||
run_extensions_installers(settings_file=args.ui_settings_file)
|
run_extensions_installers(settings_file=args.ui_settings_file)
|
||||||
|
|
||||||
@ -305,7 +305,7 @@ def prepare_environment():
|
|||||||
|
|
||||||
if args.update_all_extensions:
|
if args.update_all_extensions:
|
||||||
git_pull_recursive(extensions_dir)
|
git_pull_recursive(extensions_dir)
|
||||||
|
|
||||||
if "--exit" in sys.argv:
|
if "--exit" in sys.argv:
|
||||||
print("Exiting because of --exit argument")
|
print("Exiting because of --exit argument")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
BIN
modules/Roboto-Regular.ttf
Normal file
BIN
modules/Roboto-Regular.ttf
Normal file
Binary file not shown.
@ -6,7 +6,6 @@ import uvicorn
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from gradio.processing_utils import decode_base64_to_file
|
|
||||||
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
from fastapi import APIRouter, Depends, FastAPI, Request, Response
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
@ -16,7 +15,8 @@ from secrets import compare_digest
|
|||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
|
||||||
from modules.api.models import *
|
from modules.api import models
|
||||||
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
|
||||||
from modules.textual_inversion.preprocess import preprocess
|
from modules.textual_inversion.preprocess import preprocess
|
||||||
@ -26,21 +26,24 @@ from modules.sd_models import checkpoints_list, unload_model_weights, reload_mod
|
|||||||
from modules.sd_models_config import find_checkpoint_config_near_filename
|
from modules.sd_models_config import find_checkpoint_config_near_filename
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from modules import devices
|
from modules import devices
|
||||||
from typing import List
|
from typing import Dict, List, Any
|
||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
|
|
||||||
|
|
||||||
def upscaler_to_index(name: str):
|
def upscaler_to_index(name: str):
|
||||||
try:
|
try:
|
||||||
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
|
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}") from e
|
||||||
|
|
||||||
|
|
||||||
def script_name_to_index(name, scripts):
|
def script_name_to_index(name, scripts):
|
||||||
try:
|
try:
|
||||||
return [script.title().lower() for script in scripts].index(name.lower())
|
return [script.title().lower() for script in scripts].index(name.lower())
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=422, detail=f"Script '{name}' not found")
|
raise HTTPException(status_code=422, detail=f"Script '{name}' not found") from e
|
||||||
|
|
||||||
|
|
||||||
def validate_sampler_name(name):
|
def validate_sampler_name(name):
|
||||||
config = sd_samplers.all_samplers_map.get(name, None)
|
config = sd_samplers.all_samplers_map.get(name, None)
|
||||||
@ -49,20 +52,23 @@ def validate_sampler_name(name):
|
|||||||
|
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def setUpscalers(req: dict):
|
def setUpscalers(req: dict):
|
||||||
reqDict = vars(req)
|
reqDict = vars(req)
|
||||||
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
|
||||||
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
|
||||||
return reqDict
|
return reqDict
|
||||||
|
|
||||||
|
|
||||||
def decode_base64_to_image(encoding):
|
def decode_base64_to_image(encoding):
|
||||||
if encoding.startswith("data:image/"):
|
if encoding.startswith("data:image/"):
|
||||||
encoding = encoding.split(";")[1].split(",")[1]
|
encoding = encoding.split(";")[1].split(",")[1]
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
image = Image.open(BytesIO(base64.b64decode(encoding)))
|
||||||
return image
|
return image
|
||||||
except Exception as err:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail="Invalid encoded image")
|
raise HTTPException(status_code=500, detail="Invalid encoded image") from e
|
||||||
|
|
||||||
|
|
||||||
def encode_pil_to_base64(image):
|
def encode_pil_to_base64(image):
|
||||||
with io.BytesIO() as output_bytes:
|
with io.BytesIO() as output_bytes:
|
||||||
@ -93,6 +99,7 @@ def encode_pil_to_base64(image):
|
|||||||
|
|
||||||
return base64.b64encode(bytes_data)
|
return base64.b64encode(bytes_data)
|
||||||
|
|
||||||
|
|
||||||
def api_middleware(app: FastAPI):
|
def api_middleware(app: FastAPI):
|
||||||
rich_available = True
|
rich_available = True
|
||||||
try:
|
try:
|
||||||
@ -100,7 +107,7 @@ def api_middleware(app: FastAPI):
|
|||||||
import starlette # importing just so it can be placed on silent list
|
import starlette # importing just so it can be placed on silent list
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
console = Console()
|
console = Console()
|
||||||
except:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
rich_available = False
|
rich_available = False
|
||||||
|
|
||||||
@ -131,8 +138,8 @@ def api_middleware(app: FastAPI):
|
|||||||
"body": vars(e).get('body', ''),
|
"body": vars(e).get('body', ''),
|
||||||
"errors": str(e),
|
"errors": str(e),
|
||||||
}
|
}
|
||||||
print(f"API error: {request.method}: {request.url} {err}")
|
|
||||||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions
|
||||||
|
print(f"API error: {request.method}: {request.url} {err}")
|
||||||
if rich_available:
|
if rich_available:
|
||||||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200]))
|
||||||
else:
|
else:
|
||||||
@ -158,7 +165,7 @@ def api_middleware(app: FastAPI):
|
|||||||
class Api:
|
class Api:
|
||||||
def __init__(self, app: FastAPI, queue_lock: Lock):
|
def __init__(self, app: FastAPI, queue_lock: Lock):
|
||||||
if shared.cmd_opts.api_auth:
|
if shared.cmd_opts.api_auth:
|
||||||
self.credentials = dict()
|
self.credentials = {}
|
||||||
for auth in shared.cmd_opts.api_auth.split(","):
|
for auth in shared.cmd_opts.api_auth.split(","):
|
||||||
user, password = auth.split(":")
|
user, password = auth.split(":")
|
||||||
self.credentials[user] = password
|
self.credentials[user] = password
|
||||||
@ -167,36 +174,37 @@ class Api:
|
|||||||
self.app = app
|
self.app = app
|
||||||
self.queue_lock = queue_lock
|
self.queue_lock = queue_lock
|
||||||
api_middleware(self.app)
|
api_middleware(self.app)
|
||||||
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
|
self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=models.TextToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
|
self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=models.ImageToImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
|
self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
|
||||||
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
|
self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
|
||||||
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
|
self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse)
|
||||||
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
|
self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse)
|
||||||
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
|
self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel)
|
||||||
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
|
self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel)
|
||||||
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem])
|
||||||
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem])
|
||||||
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem])
|
||||||
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem])
|
||||||
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem])
|
||||||
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
|
||||||
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
|
self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
|
||||||
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
|
self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
|
||||||
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
|
self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
|
||||||
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse)
|
self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
|
self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse)
|
||||||
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
|
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse)
|
||||||
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
|
||||||
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
|
self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList)
|
||||||
|
self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo])
|
||||||
|
|
||||||
self.default_script_arg_txt2img = []
|
self.default_script_arg_txt2img = []
|
||||||
self.default_script_arg_img2img = []
|
self.default_script_arg_img2img = []
|
||||||
@ -220,17 +228,25 @@ class Api:
|
|||||||
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
script_idx = script_name_to_index(script_name, script_runner.selectable_scripts)
|
||||||
script = script_runner.selectable_scripts[script_idx]
|
script = script_runner.selectable_scripts[script_idx]
|
||||||
return script, script_idx
|
return script, script_idx
|
||||||
|
|
||||||
def get_scripts_list(self):
|
|
||||||
t2ilist = [str(title.lower()) for title in scripts.scripts_txt2img.titles]
|
|
||||||
i2ilist = [str(title.lower()) for title in scripts.scripts_img2img.titles]
|
|
||||||
|
|
||||||
return ScriptsList(txt2img = t2ilist, img2img = i2ilist)
|
def get_scripts_list(self):
|
||||||
|
t2ilist = [script.name for script in scripts.scripts_txt2img.scripts if script.name is not None]
|
||||||
|
i2ilist = [script.name for script in scripts.scripts_img2img.scripts if script.name is not None]
|
||||||
|
|
||||||
|
return models.ScriptsList(txt2img=t2ilist, img2img=i2ilist)
|
||||||
|
|
||||||
|
def get_script_info(self):
|
||||||
|
res = []
|
||||||
|
|
||||||
|
for script_list in [scripts.scripts_txt2img.scripts, scripts.scripts_img2img.scripts]:
|
||||||
|
res += [script.api_info for script in script_list if script.api_info is not None]
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def get_script(self, script_name, script_runner):
|
def get_script(self, script_name, script_runner):
|
||||||
if script_name is None or script_name == "":
|
if script_name is None or script_name == "":
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
script_idx = script_name_to_index(script_name, script_runner.scripts)
|
||||||
return script_runner.scripts[script_idx]
|
return script_runner.scripts[script_idx]
|
||||||
|
|
||||||
@ -265,17 +281,19 @@ class Api:
|
|||||||
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
if request.alwayson_scripts and (len(request.alwayson_scripts) > 0):
|
||||||
for alwayson_script_name in request.alwayson_scripts.keys():
|
for alwayson_script_name in request.alwayson_scripts.keys():
|
||||||
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
alwayson_script = self.get_script(alwayson_script_name, script_runner)
|
||||||
if alwayson_script == None:
|
if alwayson_script is None:
|
||||||
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
raise HTTPException(status_code=422, detail=f"always on script {alwayson_script_name} not found")
|
||||||
# Selectable script in always on script param check
|
# Selectable script in always on script param check
|
||||||
if alwayson_script.alwayson == False:
|
if alwayson_script.alwayson is False:
|
||||||
raise HTTPException(status_code=422, detail=f"Cannot have a selectable script in the always on scripts params")
|
raise HTTPException(status_code=422, detail="Cannot have a selectable script in the always on scripts params")
|
||||||
# always on script with no arg should always run so you don't really need to add them to the requests
|
# always on script with no arg should always run so you don't really need to add them to the requests
|
||||||
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
if "args" in request.alwayson_scripts[alwayson_script_name]:
|
||||||
script_args[alwayson_script.args_from:alwayson_script.args_to] = request.alwayson_scripts[alwayson_script_name]["args"]
|
# min between arg length in scriptrunner and arg length in the request
|
||||||
|
for idx in range(0, min((alwayson_script.args_to - alwayson_script.args_from), len(request.alwayson_scripts[alwayson_script_name]["args"]))):
|
||||||
|
script_args[alwayson_script.args_from + idx] = request.alwayson_scripts[alwayson_script_name]["args"][idx]
|
||||||
return script_args
|
return script_args
|
||||||
|
|
||||||
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
|
def text2imgapi(self, txt2imgreq: models.StableDiffusionTxt2ImgProcessingAPI):
|
||||||
script_runner = scripts.scripts_txt2img
|
script_runner = scripts.scripts_txt2img
|
||||||
if not script_runner.scripts:
|
if not script_runner.scripts:
|
||||||
script_runner.initialize_scripts(False)
|
script_runner.initialize_scripts(False)
|
||||||
@ -309,7 +327,7 @@ class Api:
|
|||||||
p.outpath_samples = opts.outdir_txt2img_samples
|
p.outpath_samples = opts.outdir_txt2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if selectable_scripts != None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_txt2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
@ -319,9 +337,9 @@ class Api:
|
|||||||
|
|
||||||
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
|
||||||
|
|
||||||
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
return models.TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
|
||||||
|
|
||||||
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
|
def img2imgapi(self, img2imgreq: models.StableDiffusionImg2ImgProcessingAPI):
|
||||||
init_images = img2imgreq.init_images
|
init_images = img2imgreq.init_images
|
||||||
if init_images is None:
|
if init_images is None:
|
||||||
raise HTTPException(status_code=404, detail="Init image not found")
|
raise HTTPException(status_code=404, detail="Init image not found")
|
||||||
@ -366,7 +384,7 @@ class Api:
|
|||||||
p.outpath_samples = opts.outdir_img2img_samples
|
p.outpath_samples = opts.outdir_img2img_samples
|
||||||
|
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
if selectable_scripts != None:
|
if selectable_scripts is not None:
|
||||||
p.script_args = script_args
|
p.script_args = script_args
|
||||||
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here
|
||||||
else:
|
else:
|
||||||
@ -380,9 +398,9 @@ class Api:
|
|||||||
img2imgreq.init_images = None
|
img2imgreq.init_images = None
|
||||||
img2imgreq.mask = None
|
img2imgreq.mask = None
|
||||||
|
|
||||||
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
return models.ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
|
||||||
|
|
||||||
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
|
def extras_single_image_api(self, req: models.ExtrasSingleImageRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
reqDict['image'] = decode_base64_to_image(reqDict['image'])
|
||||||
@ -390,31 +408,26 @@ class Api:
|
|||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
return models.ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
|
||||||
|
|
||||||
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
|
def extras_batch_images_api(self, req: models.ExtrasBatchImagesRequest):
|
||||||
reqDict = setUpscalers(req)
|
reqDict = setUpscalers(req)
|
||||||
|
|
||||||
def prepareFiles(file):
|
image_list = reqDict.pop('imageList', [])
|
||||||
file = decode_base64_to_file(file.data, file_path=file.name)
|
image_folder = [decode_base64_to_image(x.data) for x in image_list]
|
||||||
file.orig_name = file.name
|
|
||||||
return file
|
|
||||||
|
|
||||||
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
|
|
||||||
reqDict.pop('imageList')
|
|
||||||
|
|
||||||
with self.queue_lock:
|
with self.queue_lock:
|
||||||
result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
result = postprocessing.run_extras(extras_mode=1, image_folder=image_folder, image="", input_dir="", output_dir="", save_output=False, **reqDict)
|
||||||
|
|
||||||
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
return models.ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
|
||||||
|
|
||||||
def pnginfoapi(self, req: PNGInfoRequest):
|
def pnginfoapi(self, req: models.PNGInfoRequest):
|
||||||
if(not req.image.strip()):
|
if(not req.image.strip()):
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
image = decode_base64_to_image(req.image.strip())
|
image = decode_base64_to_image(req.image.strip())
|
||||||
if image is None:
|
if image is None:
|
||||||
return PNGInfoResponse(info="")
|
return models.PNGInfoResponse(info="")
|
||||||
|
|
||||||
geninfo, items = images.read_info_from_image(image)
|
geninfo, items = images.read_info_from_image(image)
|
||||||
if geninfo is None:
|
if geninfo is None:
|
||||||
@ -422,13 +435,13 @@ class Api:
|
|||||||
|
|
||||||
items = {**{'parameters': geninfo}, **items}
|
items = {**{'parameters': geninfo}, **items}
|
||||||
|
|
||||||
return PNGInfoResponse(info=geninfo, items=items)
|
return models.PNGInfoResponse(info=geninfo, items=items)
|
||||||
|
|
||||||
def progressapi(self, req: ProgressRequest = Depends()):
|
def progressapi(self, req: models.ProgressRequest = Depends()):
|
||||||
# copy from check_progress_call of ui.py
|
# copy from check_progress_call of ui.py
|
||||||
|
|
||||||
if shared.state.job_count == 0:
|
if shared.state.job_count == 0:
|
||||||
return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
# avoid dividing zero
|
# avoid dividing zero
|
||||||
progress = 0.01
|
progress = 0.01
|
||||||
@ -450,9 +463,9 @@ class Api:
|
|||||||
if shared.state.current_image and not req.skip_current_image:
|
if shared.state.current_image and not req.skip_current_image:
|
||||||
current_image = encode_pil_to_base64(shared.state.current_image)
|
current_image = encode_pil_to_base64(shared.state.current_image)
|
||||||
|
|
||||||
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
return models.ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
def interrogateapi(self, interrogatereq: InterrogateRequest):
|
def interrogateapi(self, interrogatereq: models.InterrogateRequest):
|
||||||
image_b64 = interrogatereq.image
|
image_b64 = interrogatereq.image
|
||||||
if image_b64 is None:
|
if image_b64 is None:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=404, detail="Image not found")
|
||||||
@ -469,7 +482,7 @@ class Api:
|
|||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail="Model not found")
|
raise HTTPException(status_code=404, detail="Model not found")
|
||||||
|
|
||||||
return InterrogateResponse(caption=processed)
|
return models.InterrogateResponse(caption=processed)
|
||||||
|
|
||||||
def interruptapi(self):
|
def interruptapi(self):
|
||||||
shared.state.interrupt()
|
shared.state.interrupt()
|
||||||
@ -574,36 +587,36 @@ class Api:
|
|||||||
filename = create_embedding(**args) # create empty embedding
|
filename = create_embedding(**args) # create empty embedding
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
|
return models.CreateResponse(info=f"create embedding filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "create embedding error: {error}".format(error = e))
|
return models.TrainResponse(info=f"create embedding error: {e}")
|
||||||
|
|
||||||
def create_hypernetwork(self, args: dict):
|
def create_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
filename = create_hypernetwork(**args) # create empty embedding
|
filename = create_hypernetwork(**args) # create empty embedding
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
|
return models.CreateResponse(info=f"create hypernetwork filename: {filename}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
|
return models.TrainResponse(info=f"create hypernetwork error: {e}")
|
||||||
|
|
||||||
def preprocess(self, args: dict):
|
def preprocess(self, args: dict):
|
||||||
try:
|
try:
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
preprocess(**args) # quick operation unless blip/booru interrogation is enabled
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess complete')
|
return models.PreprocessResponse(info = 'preprocess complete')
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
|
return models.PreprocessResponse(info=f"preprocess error: invalid token: {e}")
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
|
return models.PreprocessResponse(info=f"preprocess error: {e}")
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
|
return models.PreprocessResponse(info=f'preprocess error: {e}')
|
||||||
|
|
||||||
def train_embedding(self, args: dict):
|
def train_embedding(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -621,10 +634,10 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError as msg:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
|
return models.TrainResponse(info=f"train embedding error: {msg}")
|
||||||
|
|
||||||
def train_hypernetwork(self, args: dict):
|
def train_hypernetwork(self, args: dict):
|
||||||
try:
|
try:
|
||||||
@ -645,14 +658,15 @@ class Api:
|
|||||||
if not apply_optimizations:
|
if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
sd_hijack.apply_optimizations()
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
|
return models.TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
|
||||||
except AssertionError as msg:
|
except AssertionError:
|
||||||
shared.state.end()
|
shared.state.end()
|
||||||
return TrainResponse(info="train embedding error: {error}".format(error=error))
|
return models.TrainResponse(info=f"train embedding error: {error}")
|
||||||
|
|
||||||
def get_memory(self):
|
def get_memory(self):
|
||||||
try:
|
try:
|
||||||
import os, psutil
|
import os
|
||||||
|
import psutil
|
||||||
process = psutil.Process(os.getpid())
|
process = psutil.Process(os.getpid())
|
||||||
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values
|
||||||
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe
|
||||||
@ -679,10 +693,10 @@ class Api:
|
|||||||
'events': warnings,
|
'events': warnings,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
cuda = { 'error': 'unavailable' }
|
cuda = {'error': 'unavailable'}
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
cuda = { 'error': f'{err}' }
|
cuda = {'error': f'{err}'}
|
||||||
return MemoryResponse(ram = ram, cuda = cuda)
|
return models.MemoryResponse(ram=ram, cuda=cuda)
|
||||||
|
|
||||||
def launch(self, server_name, port):
|
def launch(self, server_name, port):
|
||||||
self.app.include_router(self.router)
|
self.app.include_router(self.router)
|
||||||
|
@ -223,8 +223,9 @@ for key in _options:
|
|||||||
if(_options[key].dest != 'help'):
|
if(_options[key].dest != 'help'):
|
||||||
flag = _options[key]
|
flag = _options[key]
|
||||||
_type = str
|
_type = str
|
||||||
if _options[key].default is not None: _type = type(_options[key].default)
|
if _options[key].default is not None:
|
||||||
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
|
_type = type(_options[key].default)
|
||||||
|
flags.update({flag.dest: (_type, Field(default=flag.default, description=flag.help))})
|
||||||
|
|
||||||
FlagsModel = create_model("Flags", **flags)
|
FlagsModel = create_model("Flags", **flags)
|
||||||
|
|
||||||
@ -286,6 +287,23 @@ class MemoryResponse(BaseModel):
|
|||||||
ram: dict = Field(title="RAM", description="System memory stats")
|
ram: dict = Field(title="RAM", description="System memory stats")
|
||||||
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats")
|
||||||
|
|
||||||
|
|
||||||
class ScriptsList(BaseModel):
|
class ScriptsList(BaseModel):
|
||||||
txt2img: list = Field(default=None,title="Txt2img", description="Titles of scripts (txt2img)")
|
txt2img: list = Field(default=None, title="Txt2img", description="Titles of scripts (txt2img)")
|
||||||
img2img: list = Field(default=None,title="Img2img", description="Titles of scripts (img2img)")
|
img2img: list = Field(default=None, title="Img2img", description="Titles of scripts (img2img)")
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptArg(BaseModel):
|
||||||
|
label: str = Field(default=None, title="Label", description="Name of the argument in UI")
|
||||||
|
value: Optional[Any] = Field(default=None, title="Value", description="Default value of the argument")
|
||||||
|
minimum: Optional[Any] = Field(default=None, title="Minimum", description="Minimum allowed value for the argumentin UI")
|
||||||
|
maximum: Optional[Any] = Field(default=None, title="Minimum", description="Maximum allowed value for the argumentin UI")
|
||||||
|
step: Optional[Any] = Field(default=None, title="Minimum", description="Step for changing value of the argumentin UI")
|
||||||
|
choices: Optional[List[str]] = Field(default=None, title="Choices", description="Possible values for the argument")
|
||||||
|
|
||||||
|
|
||||||
|
class ScriptInfo(BaseModel):
|
||||||
|
name: str = Field(default=None, title="Name", description="Script name")
|
||||||
|
is_alwayson: bool = Field(default=None, title="IsAlwayson", description="Flag specifying whether this script is an alwayson script")
|
||||||
|
is_img2img: bool = Field(default=None, title="IsImg2img", description="Flag specifying whether this script is an img2img script")
|
||||||
|
args: List[ScriptArg] = Field(title="Arguments", description="List of script's arguments")
|
||||||
|
@ -35,6 +35,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
res = func(*args, **kwargs)
|
res = func(*args, **kwargs)
|
||||||
|
progress.record_results(id_task, res)
|
||||||
finally:
|
finally:
|
||||||
progress.finish_task(id_task)
|
progress.finish_task(id_task)
|
||||||
|
|
||||||
@ -59,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
max_debug_str_len = 131072 # (1024*1024)/8
|
max_debug_str_len = 131072 # (1024*1024)/8
|
||||||
|
|
||||||
print("Error completing request", file=sys.stderr)
|
print("Error completing request", file=sys.stderr)
|
||||||
argStr = f"Arguments: {str(args)} {str(kwargs)}"
|
argStr = f"Arguments: {args} {kwargs}"
|
||||||
print(argStr[:max_debug_str_len], file=sys.stderr)
|
print(argStr[:max_debug_str_len], file=sys.stderr)
|
||||||
if len(argStr) > max_debug_str_len:
|
if len(argStr) > max_debug_str_len:
|
||||||
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
|
||||||
@ -72,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
|
|||||||
if extra_outputs_array is None:
|
if extra_outputs_array is None:
|
||||||
extra_outputs_array = [None, '']
|
extra_outputs_array = [None, '']
|
||||||
|
|
||||||
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
|
error_message = f'{type(e).__name__}: {e}'
|
||||||
|
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
|
||||||
|
|
||||||
shared.state.skipped = False
|
shared.state.skipped = False
|
||||||
shared.state.interrupted = False
|
shared.state.interrupted = False
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, sd_default_config, sd_model_file # noqa: F401
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
@ -95,9 +95,12 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin(
|
|||||||
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None)
|
||||||
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None)
|
||||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
|
||||||
|
parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None)
|
||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
|
||||||
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True)
|
||||||
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions")
|
||||||
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
|
||||||
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
|
||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
|
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
|
||||||
|
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
|
||||||
|
@ -1,14 +1,12 @@
|
|||||||
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Optional, List
|
from typing import Optional
|
||||||
|
|
||||||
from modules.codeformer.vqgan_arch import *
|
from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock
|
||||||
from basicsr.utils import get_root_logger
|
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
def calc_mean_std(feat, eps=1e-5):
|
def calc_mean_std(feat, eps=1e-5):
|
||||||
@ -121,7 +119,7 @@ class TransformerSALayer(nn.Module):
|
|||||||
tgt_mask: Optional[Tensor] = None,
|
tgt_mask: Optional[Tensor] = None,
|
||||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
query_pos: Optional[Tensor] = None):
|
query_pos: Optional[Tensor] = None):
|
||||||
|
|
||||||
# self attention
|
# self attention
|
||||||
tgt2 = self.norm1(tgt)
|
tgt2 = self.norm1(tgt)
|
||||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||||
@ -161,10 +159,10 @@ class Fuse_sft_block(nn.Module):
|
|||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class CodeFormer(VQAutoEncoder):
|
class CodeFormer(VQAutoEncoder):
|
||||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||||
codebook_size=1024, latent_size=256,
|
codebook_size=1024, latent_size=256,
|
||||||
connect_list=['32', '64', '128', '256'],
|
connect_list=('32', '64', '128', '256'),
|
||||||
fix_modules=['quantize','generator']):
|
fix_modules=('quantize', 'generator')):
|
||||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||||
|
|
||||||
if fix_modules is not None:
|
if fix_modules is not None:
|
||||||
@ -181,14 +179,14 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||||
|
|
||||||
# transformer
|
# transformer
|
||||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||||
for _ in range(self.n_layers)])
|
for _ in range(self.n_layers)])
|
||||||
|
|
||||||
# logits_predict head
|
# logits_predict head
|
||||||
self.idx_pred_layer = nn.Sequential(
|
self.idx_pred_layer = nn.Sequential(
|
||||||
nn.LayerNorm(dim_embd),
|
nn.LayerNorm(dim_embd),
|
||||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||||
|
|
||||||
self.channels = {
|
self.channels = {
|
||||||
'16': 512,
|
'16': 512,
|
||||||
'32': 256,
|
'32': 256,
|
||||||
@ -223,7 +221,7 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
enc_feat_dict = {}
|
enc_feat_dict = {}
|
||||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||||
for i, block in enumerate(self.encoder.blocks):
|
for i, block in enumerate(self.encoder.blocks):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if i in out_list:
|
if i in out_list:
|
||||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||||
|
|
||||||
@ -268,11 +266,11 @@ class CodeFormer(VQAutoEncoder):
|
|||||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||||
|
|
||||||
for i, block in enumerate(self.generator.blocks):
|
for i, block in enumerate(self.generator.blocks):
|
||||||
x = block(x)
|
x = block(x)
|
||||||
if i in fuse_list: # fuse after i-th block
|
if i in fuse_list: # fuse after i-th block
|
||||||
f_size = str(x.shape[-1])
|
f_size = str(x.shape[-1])
|
||||||
if w>0:
|
if w>0:
|
||||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||||
out = x
|
out = x
|
||||||
# logits doesn't need softmax before cross_entropy loss
|
# logits doesn't need softmax before cross_entropy loss
|
||||||
return out, logits, lq_feat
|
return out, logits, lq_feat
|
||||||
|
@ -5,17 +5,15 @@ VQGAN code, adapted from the original created by the Unleashing Transformers aut
|
|||||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||||
|
|
||||||
'''
|
'''
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import copy
|
|
||||||
from basicsr.utils import get_root_logger
|
from basicsr.utils import get_root_logger
|
||||||
from basicsr.utils.registry import ARCH_REGISTRY
|
from basicsr.utils.registry import ARCH_REGISTRY
|
||||||
|
|
||||||
def normalize(in_channels):
|
def normalize(in_channels):
|
||||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def swish(x):
|
def swish(x):
|
||||||
@ -212,15 +210,15 @@ class AttnBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q = q.reshape(b, c, h*w)
|
q = q.reshape(b, c, h*w)
|
||||||
q = q.permute(0, 2, 1)
|
q = q.permute(0, 2, 1)
|
||||||
k = k.reshape(b, c, h*w)
|
k = k.reshape(b, c, h*w)
|
||||||
w_ = torch.bmm(q, k)
|
w_ = torch.bmm(q, k)
|
||||||
w_ = w_ * (int(c)**(-0.5))
|
w_ = w_ * (int(c)**(-0.5))
|
||||||
w_ = F.softmax(w_, dim=2)
|
w_ = F.softmax(w_, dim=2)
|
||||||
|
|
||||||
# attend to values
|
# attend to values
|
||||||
v = v.reshape(b, c, h*w)
|
v = v.reshape(b, c, h*w)
|
||||||
w_ = w_.permute(0, 2, 1)
|
w_ = w_.permute(0, 2, 1)
|
||||||
h_ = torch.bmm(v, w_)
|
h_ = torch.bmm(v, w_)
|
||||||
h_ = h_.reshape(b, c, h, w)
|
h_ = h_.reshape(b, c, h, w)
|
||||||
|
|
||||||
@ -272,18 +270,18 @@ class Encoder(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Generator(nn.Module):
|
class Generator(nn.Module):
|
||||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.num_resolutions = len(self.ch_mult)
|
self.num_resolutions = len(self.ch_mult)
|
||||||
self.num_res_blocks = res_blocks
|
self.num_res_blocks = res_blocks
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions
|
||||||
self.in_channels = emb_dim
|
self.in_channels = emb_dim
|
||||||
self.out_channels = 3
|
self.out_channels = 3
|
||||||
@ -317,29 +315,29 @@ class Generator(nn.Module):
|
|||||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||||
|
|
||||||
self.blocks = nn.ModuleList(blocks)
|
self.blocks = nn.ModuleList(blocks)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ARCH_REGISTRY.register()
|
@ARCH_REGISTRY.register()
|
||||||
class VQAutoEncoder(nn.Module):
|
class VQAutoEncoder(nn.Module):
|
||||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256,
|
||||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
logger = get_root_logger()
|
logger = get_root_logger()
|
||||||
self.in_channels = 3
|
self.in_channels = 3
|
||||||
self.nf = nf
|
self.nf = nf
|
||||||
self.n_blocks = res_blocks
|
self.n_blocks = res_blocks
|
||||||
self.codebook_size = codebook_size
|
self.codebook_size = codebook_size
|
||||||
self.embed_dim = emb_dim
|
self.embed_dim = emb_dim
|
||||||
self.ch_mult = ch_mult
|
self.ch_mult = ch_mult
|
||||||
self.resolution = img_size
|
self.resolution = img_size
|
||||||
self.attn_resolutions = attn_resolutions
|
self.attn_resolutions = attn_resolutions or [16]
|
||||||
self.quantizer_type = quantizer
|
self.quantizer_type = quantizer
|
||||||
self.encoder = Encoder(
|
self.encoder = Encoder(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
@ -365,11 +363,11 @@ class VQAutoEncoder(nn.Module):
|
|||||||
self.kl_weight
|
self.kl_weight
|
||||||
)
|
)
|
||||||
self.generator = Generator(
|
self.generator = Generator(
|
||||||
self.nf,
|
self.nf,
|
||||||
self.embed_dim,
|
self.embed_dim,
|
||||||
self.ch_mult,
|
self.ch_mult,
|
||||||
self.n_blocks,
|
self.n_blocks,
|
||||||
self.resolution,
|
self.resolution,
|
||||||
self.attn_resolutions
|
self.attn_resolutions
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -434,4 +432,4 @@ class VQGANDiscriminator(nn.Module):
|
|||||||
raise ValueError('Wrong params!')
|
raise ValueError('Wrong params!')
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.main(x)
|
return self.main(x)
|
||||||
|
@ -33,11 +33,9 @@ def setup_model(dirname):
|
|||||||
try:
|
try:
|
||||||
from torchvision.transforms.functional import normalize
|
from torchvision.transforms.functional import normalize
|
||||||
from modules.codeformer.codeformer_arch import CodeFormer
|
from modules.codeformer.codeformer_arch import CodeFormer
|
||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils import img2tensor, tensor2img
|
||||||
from basicsr.utils import imwrite, img2tensor, tensor2img
|
|
||||||
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
from facelib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
from facelib.detection.retinaface import retinaface
|
from facelib.detection.retinaface import retinaface
|
||||||
from modules.shared import cmd_opts
|
|
||||||
|
|
||||||
net_class = CodeFormer
|
net_class = CodeFormer
|
||||||
|
|
||||||
@ -96,7 +94,7 @@ def setup_model(dirname):
|
|||||||
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
|
||||||
self.face_helper.align_warp_face()
|
self.face_helper.align_warp_face()
|
||||||
|
|
||||||
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
|
for cropped_face in self.face_helper.cropped_faces:
|
||||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
|
||||||
|
202
modules/config_states.py
Normal file
202
modules/config_states.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
"""
|
||||||
|
Supports saving and restoring webui and extensions from a known working set of commits
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import OrderedDict
|
||||||
|
import git
|
||||||
|
|
||||||
|
from modules import shared, extensions
|
||||||
|
from modules.paths_internal import script_path, config_states_dir
|
||||||
|
|
||||||
|
|
||||||
|
all_config_states = OrderedDict()
|
||||||
|
|
||||||
|
|
||||||
|
def list_config_states():
|
||||||
|
global all_config_states
|
||||||
|
|
||||||
|
all_config_states.clear()
|
||||||
|
os.makedirs(config_states_dir, exist_ok=True)
|
||||||
|
|
||||||
|
config_states = []
|
||||||
|
for filename in os.listdir(config_states_dir):
|
||||||
|
if filename.endswith(".json"):
|
||||||
|
path = os.path.join(config_states_dir, filename)
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
j = json.load(f)
|
||||||
|
j["filepath"] = path
|
||||||
|
config_states.append(j)
|
||||||
|
|
||||||
|
config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)
|
||||||
|
|
||||||
|
for cs in config_states:
|
||||||
|
timestamp = time.asctime(time.gmtime(cs["created_at"]))
|
||||||
|
name = cs.get("name", "Config")
|
||||||
|
full_name = f"{name}: {timestamp}"
|
||||||
|
all_config_states[full_name] = cs
|
||||||
|
|
||||||
|
return all_config_states
|
||||||
|
|
||||||
|
|
||||||
|
def get_webui_config():
|
||||||
|
webui_repo = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
|
webui_repo = git.Repo(script_path)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
webui_remote = None
|
||||||
|
webui_commit_hash = None
|
||||||
|
webui_commit_date = None
|
||||||
|
webui_branch = None
|
||||||
|
if webui_repo and not webui_repo.bare:
|
||||||
|
try:
|
||||||
|
webui_remote = next(webui_repo.remote().urls, None)
|
||||||
|
head = webui_repo.head.commit
|
||||||
|
webui_commit_date = webui_repo.head.commit.committed_date
|
||||||
|
webui_commit_hash = head.hexsha
|
||||||
|
webui_branch = webui_repo.active_branch.name
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
webui_remote = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"remote": webui_remote,
|
||||||
|
"commit_hash": webui_commit_hash,
|
||||||
|
"commit_date": webui_commit_date,
|
||||||
|
"branch": webui_branch,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_extension_config():
|
||||||
|
ext_config = {}
|
||||||
|
|
||||||
|
for ext in extensions.extensions:
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"name": ext.name,
|
||||||
|
"path": ext.path,
|
||||||
|
"enabled": ext.enabled,
|
||||||
|
"is_builtin": ext.is_builtin,
|
||||||
|
"remote": ext.remote,
|
||||||
|
"commit_hash": ext.commit_hash,
|
||||||
|
"commit_date": ext.commit_date,
|
||||||
|
"branch": ext.branch,
|
||||||
|
"have_info_from_repo": ext.have_info_from_repo
|
||||||
|
}
|
||||||
|
|
||||||
|
ext_config[ext.name] = entry
|
||||||
|
|
||||||
|
return ext_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_config():
|
||||||
|
creation_time = datetime.now().timestamp()
|
||||||
|
webui_config = get_webui_config()
|
||||||
|
ext_config = get_extension_config()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"created_at": creation_time,
|
||||||
|
"webui": webui_config,
|
||||||
|
"extensions": ext_config
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def restore_webui_config(config):
|
||||||
|
print("* Restoring webui state...")
|
||||||
|
|
||||||
|
if "webui" not in config:
|
||||||
|
print("Error: No webui data saved to config")
|
||||||
|
return
|
||||||
|
|
||||||
|
webui_config = config["webui"]
|
||||||
|
|
||||||
|
if "commit_hash" not in webui_config:
|
||||||
|
print("Error: No commit saved to webui config")
|
||||||
|
return
|
||||||
|
|
||||||
|
webui_commit_hash = webui_config.get("commit_hash", None)
|
||||||
|
webui_repo = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if os.path.exists(os.path.join(script_path, ".git")):
|
||||||
|
webui_repo = git.Repo(script_path)
|
||||||
|
except Exception:
|
||||||
|
print(f"Error reading webui git info from {script_path}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
webui_repo.git.fetch(all=True)
|
||||||
|
webui_repo.git.reset(webui_commit_hash, hard=True)
|
||||||
|
print(f"* Restored webui to commit {webui_commit_hash}.")
|
||||||
|
except Exception:
|
||||||
|
print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr)
|
||||||
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_extension_config(config):
|
||||||
|
print("* Restoring extension state...")
|
||||||
|
|
||||||
|
if "extensions" not in config:
|
||||||
|
print("Error: No extension data saved to config")
|
||||||
|
return
|
||||||
|
|
||||||
|
ext_config = config["extensions"]
|
||||||
|
|
||||||
|
results = []
|
||||||
|
disabled = []
|
||||||
|
|
||||||
|
for ext in tqdm.tqdm(extensions.extensions):
|
||||||
|
if ext.is_builtin:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
current_commit = ext.commit_hash
|
||||||
|
|
||||||
|
if ext.name not in ext_config:
|
||||||
|
ext.disabled = True
|
||||||
|
disabled.append(ext.name)
|
||||||
|
results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
entry = ext_config[ext.name]
|
||||||
|
|
||||||
|
if "commit_hash" in entry and entry["commit_hash"]:
|
||||||
|
try:
|
||||||
|
ext.fetch_and_reset_hard(entry["commit_hash"])
|
||||||
|
ext.read_info_from_repo()
|
||||||
|
if current_commit != entry["commit_hash"]:
|
||||||
|
results.append((ext, current_commit[:8], True, entry["commit_hash"][:8]))
|
||||||
|
except Exception as ex:
|
||||||
|
results.append((ext, current_commit[:8], False, ex))
|
||||||
|
else:
|
||||||
|
results.append((ext, current_commit[:8], False, "No commit hash found in config"))
|
||||||
|
|
||||||
|
if not entry.get("enabled", False):
|
||||||
|
ext.disabled = True
|
||||||
|
disabled.append(ext.name)
|
||||||
|
else:
|
||||||
|
ext.disabled = False
|
||||||
|
|
||||||
|
shared.opts.disabled_extensions = disabled
|
||||||
|
shared.opts.save(shared.config_filename)
|
||||||
|
|
||||||
|
print("* Finished restoring extensions. Results:")
|
||||||
|
for ext, prev_commit, success, result in results:
|
||||||
|
if success:
|
||||||
|
print(f" + {ext.name}: {prev_commit} -> {result}")
|
||||||
|
else:
|
||||||
|
print(f" ! {ext.name}: FAILURE ({result})")
|
@ -2,7 +2,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
from modules import modelloader, paths, deepbooru_model, devices, images, shared
|
||||||
@ -79,7 +78,7 @@ class DeepDanbooru:
|
|||||||
|
|
||||||
res = []
|
res = []
|
||||||
|
|
||||||
filtertags = set([x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")])
|
filtertags = {x.strip().replace(' ', '_') for x in shared.opts.deepbooru_filter_tags.split(",")}
|
||||||
|
|
||||||
for tag in [x for x in tags if x not in filtertags]:
|
for tag in [x for x in tags if x not in filtertags]:
|
||||||
probability = probability_dict[tag]
|
probability = probability_dict[tag]
|
||||||
|
@ -65,7 +65,7 @@ def enable_tf32():
|
|||||||
|
|
||||||
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
# enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
|
||||||
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
|
||||||
if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
|
if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
@ -92,14 +92,18 @@ def cond_cast_float(input):
|
|||||||
|
|
||||||
|
|
||||||
def randn(seed, shape):
|
def randn(seed, shape):
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
if device.type == 'mps':
|
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
return torch.randn(shape, device=cpu).to(device)
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
|
||||||
def randn_without_seed(shape):
|
def randn_without_seed(shape):
|
||||||
if device.type == 'mps':
|
from modules.shared import opts
|
||||||
|
|
||||||
|
if opts.randn_source == "CPU" or device.type == 'mps':
|
||||||
return torch.randn(shape, device=cpu).to(device)
|
return torch.randn(shape, device=cpu).to(device)
|
||||||
return torch.randn(shape, device=device)
|
return torch.randn(shape, device=device)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from PIL import Image
|
|||||||
from basicsr.utils.download_util import load_file_from_url
|
from basicsr.utils.download_util import load_file_from_url
|
||||||
|
|
||||||
import modules.esrgan_model_arch as arch
|
import modules.esrgan_model_arch as arch
|
||||||
from modules import shared, modelloader, images, devices
|
from modules import modelloader, images, devices
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
@ -16,9 +16,7 @@ def mod2normal(state_dict):
|
|||||||
# this code is copied from https://github.com/victorca25/iNNfer
|
# this code is copied from https://github.com/victorca25/iNNfer
|
||||||
if 'conv_first.weight' in state_dict:
|
if 'conv_first.weight' in state_dict:
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23):
|
|||||||
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
|
||||||
re8x = 0
|
re8x = 0
|
||||||
crt_net = {}
|
crt_net = {}
|
||||||
items = []
|
items = list(state_dict)
|
||||||
for k, v in state_dict.items():
|
|
||||||
items.append(k)
|
|
||||||
|
|
||||||
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
crt_net['model.0.weight'] = state_dict['conv_first.weight']
|
||||||
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
crt_net['model.0.bias'] = state_dict['conv_first.bias']
|
||||||
@ -156,13 +152,16 @@ class UpscalerESRGAN(Upscaler):
|
|||||||
|
|
||||||
def load_model(self, path: str):
|
def load_model(self, path: str):
|
||||||
if "http" in path:
|
if "http" in path:
|
||||||
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
|
filename = load_file_from_url(
|
||||||
file_name="%s.pth" % self.model_name,
|
url=self.model_url,
|
||||||
progress=True)
|
model_dir=self.model_path,
|
||||||
|
file_name=f"{self.model_name}.pth",
|
||||||
|
progress=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename = path
|
filename = path
|
||||||
if not os.path.exists(filename) or filename is None:
|
if not os.path.exists(filename) or filename is None:
|
||||||
print("Unable to load %s from %s" % (self.model_path, filename))
|
print(f"Unable to load {self.model_path} from {filename}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import math
|
import math
|
||||||
import functools
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -38,7 +37,7 @@ class RRDBNet(nn.Module):
|
|||||||
elif upsample_mode == 'pixelshuffle':
|
elif upsample_mode == 'pixelshuffle':
|
||||||
upsample_block = pixelshuffle_block
|
upsample_block = pixelshuffle_block
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
|
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
|
||||||
if upscale == 3:
|
if upscale == 3:
|
||||||
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
|
||||||
else:
|
else:
|
||||||
@ -106,7 +105,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
|||||||
Modified options that can be used:
|
Modified options that can be used:
|
||||||
- "Partial Convolution based Padding" arXiv:1811.11718
|
- "Partial Convolution based Padding" arXiv:1811.11718
|
||||||
- "Spectral normalization" arXiv:1802.05957
|
- "Spectral normalization" arXiv:1802.05957
|
||||||
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
- "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
|
||||||
{Rakotonirina} and A. {Rasoanaivo}
|
{Rakotonirina} and A. {Rasoanaivo}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -171,7 +170,7 @@ class GaussianNoise(nn.Module):
|
|||||||
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
|
||||||
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
sampled_noise = self.noise.repeat(*x.size()).normal_() * scale
|
||||||
x = x + sampled_noise
|
x = x + sampled_noise
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def conv1x1(in_planes, out_planes, stride=1):
|
def conv1x1(in_planes, out_planes, stride=1):
|
||||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||||
@ -261,10 +260,10 @@ class Upsample(nn.Module):
|
|||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
if self.scale_factor is not None:
|
if self.scale_factor is not None:
|
||||||
info = 'scale_factor=' + str(self.scale_factor)
|
info = f'scale_factor={self.scale_factor}'
|
||||||
else:
|
else:
|
||||||
info = 'size=' + str(self.size)
|
info = f'size={self.size}'
|
||||||
info += ', mode=' + self.mode
|
info += f', mode={self.mode}'
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
@ -350,7 +349,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
|
|||||||
elif act_type == 'sigmoid': # [0, 1] range output
|
elif act_type == 'sigmoid': # [0, 1] range output
|
||||||
layer = nn.Sigmoid()
|
layer = nn.Sigmoid()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
|
raise NotImplementedError(f'activation layer [{act_type}] is not found')
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@ -372,7 +371,7 @@ def norm(norm_type, nc):
|
|||||||
elif norm_type == 'none':
|
elif norm_type == 'none':
|
||||||
def norm_layer(x): return Identity()
|
def norm_layer(x): return Identity()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
|
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@ -388,7 +387,7 @@ def pad(pad_type, padding):
|
|||||||
elif pad_type == 'zero':
|
elif pad_type == 'zero':
|
||||||
layer = nn.ZeroPad2d(padding)
|
layer = nn.ZeroPad2d(padding)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
|
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
@ -432,15 +431,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
|
|||||||
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
|
||||||
spectral_norm=False):
|
spectral_norm=False):
|
||||||
""" Conv layer with padding, normalization, activation """
|
""" Conv layer with padding, normalization, activation """
|
||||||
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
|
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
|
||||||
padding = get_valid_padding(kernel_size, dilation)
|
padding = get_valid_padding(kernel_size, dilation)
|
||||||
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
|
||||||
padding = padding if pad_type == 'zero' else 0
|
padding = padding if pad_type == 'zero' else 0
|
||||||
|
|
||||||
if convtype=='PartialConv2D':
|
if convtype=='PartialConv2D':
|
||||||
|
from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer
|
||||||
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='DeformConv2D':
|
elif convtype=='DeformConv2D':
|
||||||
|
from torchvision.ops import DeformConv2d # not tested
|
||||||
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding,
|
||||||
dilation=dilation, bias=bias, groups=groups)
|
dilation=dilation, bias=bias, groups=groups)
|
||||||
elif convtype=='Conv3D':
|
elif convtype=='Conv3D':
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import time
|
|
||||||
import git
|
import git
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.paths_internal import extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
|
||||||
|
|
||||||
extensions = []
|
extensions = []
|
||||||
|
|
||||||
@ -24,6 +24,8 @@ def active():
|
|||||||
|
|
||||||
|
|
||||||
class Extension:
|
class Extension:
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
def __init__(self, name, path, enabled=True, is_builtin=False):
|
def __init__(self, name, path, enabled=True, is_builtin=False):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.path = path
|
self.path = path
|
||||||
@ -31,16 +33,24 @@ class Extension:
|
|||||||
self.status = ''
|
self.status = ''
|
||||||
self.can_update = False
|
self.can_update = False
|
||||||
self.is_builtin = is_builtin
|
self.is_builtin = is_builtin
|
||||||
|
self.commit_hash = ''
|
||||||
|
self.commit_date = None
|
||||||
self.version = ''
|
self.version = ''
|
||||||
|
self.branch = None
|
||||||
self.remote = None
|
self.remote = None
|
||||||
self.have_info_from_repo = False
|
self.have_info_from_repo = False
|
||||||
|
|
||||||
def read_info_from_repo(self):
|
def read_info_from_repo(self):
|
||||||
if self.have_info_from_repo:
|
if self.is_builtin or self.have_info_from_repo:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.have_info_from_repo = True
|
with self.lock:
|
||||||
|
if self.have_info_from_repo:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.do_read_info_from_repo()
|
||||||
|
|
||||||
|
def do_read_info_from_repo(self):
|
||||||
repo = None
|
repo = None
|
||||||
try:
|
try:
|
||||||
if os.path.exists(os.path.join(self.path, ".git")):
|
if os.path.exists(os.path.join(self.path, ".git")):
|
||||||
@ -55,13 +65,18 @@ class Extension:
|
|||||||
try:
|
try:
|
||||||
self.status = 'unknown'
|
self.status = 'unknown'
|
||||||
self.remote = next(repo.remote().urls, None)
|
self.remote = next(repo.remote().urls, None)
|
||||||
head = repo.head.commit
|
self.commit_date = repo.head.commit.committed_date
|
||||||
ts = time.asctime(time.gmtime(repo.head.commit.committed_date))
|
if repo.active_branch:
|
||||||
self.version = f'{head.hexsha[:8]} ({ts})'
|
self.branch = repo.active_branch.name
|
||||||
|
self.commit_hash = repo.head.commit.hexsha
|
||||||
|
self.version = repo.git.describe("--always", "--tags") # compared to `self.commit_hash[:8]` this takes about 30% more time total but since we run it in parallel we don't care
|
||||||
|
|
||||||
except Exception:
|
except Exception as ex:
|
||||||
|
print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr)
|
||||||
self.remote = None
|
self.remote = None
|
||||||
|
|
||||||
|
self.have_info_from_repo = True
|
||||||
|
|
||||||
def list_files(self, subdir, extension):
|
def list_files(self, subdir, extension):
|
||||||
from modules import scripts
|
from modules import scripts
|
||||||
|
|
||||||
@ -82,18 +97,30 @@ class Extension:
|
|||||||
for fetch in repo.remote().fetch(dry_run=True):
|
for fetch in repo.remote().fetch(dry_run=True):
|
||||||
if fetch.flags != fetch.HEAD_UPTODATE:
|
if fetch.flags != fetch.HEAD_UPTODATE:
|
||||||
self.can_update = True
|
self.can_update = True
|
||||||
self.status = "behind"
|
self.status = "new commits"
|
||||||
return
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
origin = repo.rev_parse('origin')
|
||||||
|
if repo.head.commit != origin:
|
||||||
|
self.can_update = True
|
||||||
|
self.status = "behind HEAD"
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
self.can_update = False
|
||||||
|
self.status = "unknown (remote error)"
|
||||||
|
return
|
||||||
|
|
||||||
self.can_update = False
|
self.can_update = False
|
||||||
self.status = "latest"
|
self.status = "latest"
|
||||||
|
|
||||||
def fetch_and_reset_hard(self):
|
def fetch_and_reset_hard(self, commit='origin'):
|
||||||
repo = git.Repo(self.path)
|
repo = git.Repo(self.path)
|
||||||
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
# Fix: `error: Your local changes to the following files would be overwritten by merge`,
|
||||||
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
# because WSL2 Docker set 755 file permissions instead of 644, this results to the error.
|
||||||
repo.git.fetch(all=True)
|
repo.git.fetch(all=True)
|
||||||
repo.git.reset('origin', hard=True)
|
repo.git.reset(commit, hard=True)
|
||||||
|
self.have_info_from_repo = False
|
||||||
|
|
||||||
|
|
||||||
def list_extensions():
|
def list_extensions():
|
||||||
|
@ -91,7 +91,7 @@ def deactivate(p, extra_network_data):
|
|||||||
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
"""call deactivate for extra networks in extra_network_data in specified order, then call
|
||||||
deactivate for all remaining registered networks"""
|
deactivate for all remaining registered networks"""
|
||||||
|
|
||||||
for extra_network_name, extra_network_args in extra_network_data.items():
|
for extra_network_name in extra_network_data:
|
||||||
extra_network = extra_network_registry.get(extra_network_name, None)
|
extra_network = extra_network_registry.get(extra_network_name, None)
|
||||||
if extra_network is None:
|
if extra_network is None:
|
||||||
continue
|
continue
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from modules import extra_networks, shared, extra_networks
|
from modules import extra_networks, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
|
|
||||||
|
|
||||||
@ -9,8 +9,9 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
|
|||||||
def activate(self, p, params_list):
|
def activate(self, p, params_list):
|
||||||
additional = shared.opts.sd_hypernetwork
|
additional = shared.opts.sd_hypernetwork
|
||||||
|
|
||||||
if additional != "" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
|
||||||
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
|
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
|
||||||
|
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
|
||||||
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
|
||||||
|
|
||||||
names = []
|
names = []
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -71,7 +72,7 @@ def to_half(tensor, enable):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
|
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
|
||||||
shared.state.begin()
|
shared.state.begin()
|
||||||
shared.state.job = 'model-merge'
|
shared.state.job = 'model-merge'
|
||||||
|
|
||||||
@ -135,14 +136,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
result_is_instruct_pix2pix_model = False
|
result_is_instruct_pix2pix_model = False
|
||||||
|
|
||||||
if theta_func2:
|
if theta_func2:
|
||||||
shared.state.textinfo = f"Loading B"
|
shared.state.textinfo = "Loading B"
|
||||||
print(f"Loading {secondary_model_info.filename}...")
|
print(f"Loading {secondary_model_info.filename}...")
|
||||||
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
|
||||||
else:
|
else:
|
||||||
theta_1 = None
|
theta_1 = None
|
||||||
|
|
||||||
if theta_func1:
|
if theta_func1:
|
||||||
shared.state.textinfo = f"Loading C"
|
shared.state.textinfo = "Loading C"
|
||||||
print(f"Loading {tertiary_model_info.filename}...")
|
print(f"Loading {tertiary_model_info.filename}...")
|
||||||
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
|
||||||
|
|
||||||
@ -198,7 +199,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
result_is_inpainting_model = True
|
result_is_inpainting_model = True
|
||||||
else:
|
else:
|
||||||
theta_0[key] = theta_func2(a, b, multiplier)
|
theta_0[key] = theta_func2(a, b, multiplier)
|
||||||
|
|
||||||
theta_0[key] = to_half(theta_0[key], save_as_half)
|
theta_0[key] = to_half(theta_0[key], save_as_half)
|
||||||
|
|
||||||
shared.state.sampling_step += 1
|
shared.state.sampling_step += 1
|
||||||
@ -241,13 +242,58 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
|
|||||||
shared.state.textinfo = "Saving"
|
shared.state.textinfo = "Saving"
|
||||||
print(f"Saving to {output_modelname}...")
|
print(f"Saving to {output_modelname}...")
|
||||||
|
|
||||||
|
metadata = None
|
||||||
|
|
||||||
|
if save_metadata:
|
||||||
|
metadata = {"format": "pt"}
|
||||||
|
|
||||||
|
merge_recipe = {
|
||||||
|
"type": "webui", # indicate this model was merged with webui's built-in merger
|
||||||
|
"primary_model_hash": primary_model_info.sha256,
|
||||||
|
"secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
|
||||||
|
"tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
|
||||||
|
"interp_method": interp_method,
|
||||||
|
"multiplier": multiplier,
|
||||||
|
"save_as_half": save_as_half,
|
||||||
|
"custom_name": custom_name,
|
||||||
|
"config_source": config_source,
|
||||||
|
"bake_in_vae": bake_in_vae,
|
||||||
|
"discard_weights": discard_weights,
|
||||||
|
"is_inpainting": result_is_inpainting_model,
|
||||||
|
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
|
||||||
|
}
|
||||||
|
metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
|
||||||
|
|
||||||
|
sd_merge_models = {}
|
||||||
|
|
||||||
|
def add_model_metadata(checkpoint_info):
|
||||||
|
checkpoint_info.calculate_shorthash()
|
||||||
|
sd_merge_models[checkpoint_info.sha256] = {
|
||||||
|
"name": checkpoint_info.name,
|
||||||
|
"legacy_hash": checkpoint_info.hash,
|
||||||
|
"sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
|
||||||
|
}
|
||||||
|
|
||||||
|
sd_merge_models.update(checkpoint_info.metadata.get("sd_merge_models", {}))
|
||||||
|
|
||||||
|
add_model_metadata(primary_model_info)
|
||||||
|
if secondary_model_info:
|
||||||
|
add_model_metadata(secondary_model_info)
|
||||||
|
if tertiary_model_info:
|
||||||
|
add_model_metadata(tertiary_model_info)
|
||||||
|
|
||||||
|
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
|
||||||
|
|
||||||
_, extension = os.path.splitext(output_modelname)
|
_, extension = os.path.splitext(output_modelname)
|
||||||
if extension.lower() == ".safetensors":
|
if extension.lower() == ".safetensors":
|
||||||
safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
|
safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(theta_0, output_modelname)
|
torch.save(theta_0, output_modelname)
|
||||||
|
|
||||||
sd_models.list_models()
|
sd_models.list_models()
|
||||||
|
created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
|
||||||
|
if created_model:
|
||||||
|
created_model.calculate_shorthash()
|
||||||
|
|
||||||
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
|
||||||
|
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
import base64
|
import base64
|
||||||
import html
|
|
||||||
import io
|
import io
|
||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from modules.paths import data_path
|
from modules.paths import data_path
|
||||||
from modules import shared, ui_tempdir, script_callbacks
|
from modules import shared, ui_tempdir, script_callbacks
|
||||||
import tempfile
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
|
||||||
@ -23,14 +19,14 @@ registered_param_bindings = []
|
|||||||
|
|
||||||
|
|
||||||
class ParamBinding:
|
class ParamBinding:
|
||||||
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
|
def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
|
||||||
self.paste_button = paste_button
|
self.paste_button = paste_button
|
||||||
self.tabname = tabname
|
self.tabname = tabname
|
||||||
self.source_text_component = source_text_component
|
self.source_text_component = source_text_component
|
||||||
self.source_image_component = source_image_component
|
self.source_image_component = source_image_component
|
||||||
self.source_tabname = source_tabname
|
self.source_tabname = source_tabname
|
||||||
self.override_settings_component = override_settings_component
|
self.override_settings_component = override_settings_component
|
||||||
self.paste_field_names = paste_field_names
|
self.paste_field_names = paste_field_names or []
|
||||||
|
|
||||||
|
|
||||||
def reset():
|
def reset():
|
||||||
@ -59,6 +55,7 @@ def image_from_url_text(filedata):
|
|||||||
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
|
||||||
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
|
||||||
|
|
||||||
|
filename = filename.rsplit('?', 1)[0]
|
||||||
return Image.open(filename)
|
return Image.open(filename)
|
||||||
|
|
||||||
if type(filedata) == list:
|
if type(filedata) == list:
|
||||||
@ -129,6 +126,7 @@ def connect_paste_params_buttons():
|
|||||||
_js=jsfunc,
|
_js=jsfunc,
|
||||||
inputs=[binding.source_image_component],
|
inputs=[binding.source_image_component],
|
||||||
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if binding.source_text_component is not None and fields is not None:
|
if binding.source_text_component is not None and fields is not None:
|
||||||
@ -140,6 +138,7 @@ def connect_paste_params_buttons():
|
|||||||
fn=lambda *x: x,
|
fn=lambda *x: x,
|
||||||
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
|
||||||
outputs=[field for field, name in fields if name in paste_field_names],
|
outputs=[field for field, name in fields if name in paste_field_names],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
binding.paste_button.click(
|
binding.paste_button.click(
|
||||||
@ -147,6 +146,7 @@ def connect_paste_params_buttons():
|
|||||||
_js=f"switch_to_{binding.tabname}",
|
_js=f"switch_to_{binding.tabname}",
|
||||||
inputs=None,
|
inputs=None,
|
||||||
outputs=None,
|
outputs=None,
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -247,7 +247,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
lines.append(lastline)
|
lines.append(lastline)
|
||||||
lastline = ''
|
lastline = ''
|
||||||
|
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if line.startswith("Negative prompt:"):
|
if line.startswith("Negative prompt:"):
|
||||||
done_with_prompt = True
|
done_with_prompt = True
|
||||||
@ -265,8 +265,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
|
||||||
m = re_imagesize.match(v)
|
m = re_imagesize.match(v)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
res[k+"-1"] = m.group(1)
|
res[f"{k}-1"] = m.group(1)
|
||||||
res[k+"-2"] = m.group(2)
|
res[f"{k}-2"] = m.group(2)
|
||||||
else:
|
else:
|
||||||
res[k] = v
|
res[k] = v
|
||||||
|
|
||||||
@ -284,6 +284,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
|
|||||||
|
|
||||||
restore_old_hires_fix_params(res)
|
restore_old_hires_fix_params(res)
|
||||||
|
|
||||||
|
# Missing RNG means the default was set, which is GPU RNG
|
||||||
|
if "RNG" not in res:
|
||||||
|
res["RNG"] = "GPU"
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@ -304,6 +308,10 @@ infotext_to_setting_name_mapping = [
|
|||||||
('UniPC skip type', 'uni_pc_skip_type'),
|
('UniPC skip type', 'uni_pc_skip_type'),
|
||||||
('UniPC order', 'uni_pc_order'),
|
('UniPC order', 'uni_pc_order'),
|
||||||
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
('UniPC lower order final', 'uni_pc_lower_order_final'),
|
||||||
|
('Token merging ratio', 'token_merging_ratio'),
|
||||||
|
('Token merging ratio hr', 'token_merging_ratio_hr'),
|
||||||
|
('RNG', 'randn_source'),
|
||||||
|
('NGMS', 's_min_uncond'),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -403,12 +411,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
|
|||||||
fn=paste_func,
|
fn=paste_func,
|
||||||
inputs=[input_comp],
|
inputs=[input_comp],
|
||||||
outputs=[x[0] for x in paste_fields],
|
outputs=[x[0] for x in paste_fields],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
button.click(
|
button.click(
|
||||||
fn=None,
|
fn=None,
|
||||||
_js=f"recalculate_prompts_{tabname}",
|
_js=f"recalculate_prompts_{tabname}",
|
||||||
inputs=[],
|
inputs=[],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
|
show_progress=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,7 +78,7 @@ def setup_model(dirname):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
from facexlib import detection, parsing
|
from facexlib import detection, parsing # noqa: F401
|
||||||
global user_path
|
global user_path
|
||||||
global have_gfpgan
|
global have_gfpgan
|
||||||
global gfpgan_constructor
|
global gfpgan_constructor
|
||||||
|
@ -13,7 +13,7 @@ cache_data = None
|
|||||||
|
|
||||||
|
|
||||||
def dump_cache():
|
def dump_cache():
|
||||||
with filelock.FileLock(cache_filename+".lock"):
|
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||||
with open(cache_filename, "w", encoding="utf8") as file:
|
with open(cache_filename, "w", encoding="utf8") as file:
|
||||||
json.dump(cache_data, file, indent=4)
|
json.dump(cache_data, file, indent=4)
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ def cache(subsection):
|
|||||||
global cache_data
|
global cache_data
|
||||||
|
|
||||||
if cache_data is None:
|
if cache_data is None:
|
||||||
with filelock.FileLock(cache_filename+".lock"):
|
with filelock.FileLock(f"{cache_filename}.lock"):
|
||||||
if not os.path.isfile(cache_filename):
|
if not os.path.isfile(cache_filename):
|
||||||
cache_data = {}
|
cache_data = {}
|
||||||
else:
|
else:
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import csv
|
|
||||||
import datetime
|
import datetime
|
||||||
import glob
|
import glob
|
||||||
import html
|
import html
|
||||||
@ -18,7 +17,7 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
|
|||||||
from torch import einsum
|
from torch import einsum
|
||||||
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
|
||||||
|
|
||||||
from collections import defaultdict, deque
|
from collections import deque
|
||||||
from statistics import stdev, mean
|
from statistics import stdev, mean
|
||||||
|
|
||||||
|
|
||||||
@ -178,34 +177,34 @@ class Hypernetwork:
|
|||||||
|
|
||||||
def weights(self):
|
def weights(self):
|
||||||
res = []
|
res = []
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
res += layer.parameters()
|
res += layer.parameters()
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.train(mode=mode)
|
layer.train(mode=mode)
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
param.requires_grad = mode
|
param.requires_grad = mode
|
||||||
|
|
||||||
def to(self, device):
|
def to(self, device):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.to(device)
|
layer.to(device)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def set_multiplier(self, multiplier):
|
def set_multiplier(self, multiplier):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.multiplier = multiplier
|
layer.multiplier = multiplier
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def eval(self):
|
def eval(self):
|
||||||
for k, layers in self.layers.items():
|
for layers in self.layers.values():
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
layer.eval()
|
layer.eval()
|
||||||
for param in layer.parameters():
|
for param in layer.parameters():
|
||||||
@ -404,7 +403,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
|
|||||||
k = self.to_k(context_k)
|
k = self.to_k(context_k)
|
||||||
v = self.to_v(context_v)
|
v = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||||
|
|
||||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
|
||||||
@ -541,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
return hypernetwork, filename
|
return hypernetwork, filename
|
||||||
|
|
||||||
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
|
||||||
|
|
||||||
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
|
||||||
@ -594,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
print(e)
|
print(e)
|
||||||
|
|
||||||
scaler = torch.cuda.amp.GradScaler()
|
scaler = torch.cuda.amp.GradScaler()
|
||||||
|
|
||||||
batch_size = ds.batch_size
|
batch_size = ds.batch_size
|
||||||
gradient_step = ds.gradient_step
|
gradient_step = ds.gradient_step
|
||||||
# n steps = batch_size * gradient_step * n image processed
|
# n steps = batch_size * gradient_step * n image processed
|
||||||
@ -620,7 +619,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
try:
|
try:
|
||||||
sd_hijack_checkpoint.add()
|
sd_hijack_checkpoint.add()
|
||||||
|
|
||||||
for i in range((steps-initial_step) * gradient_step):
|
for _ in range((steps-initial_step) * gradient_step):
|
||||||
if scheduler.finished:
|
if scheduler.finished:
|
||||||
break
|
break
|
||||||
if shared.state.interrupted:
|
if shared.state.interrupted:
|
||||||
@ -637,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad_sched.step(hypernetwork.step)
|
clip_grad_sched.step(hypernetwork.step)
|
||||||
|
|
||||||
with devices.autocast():
|
with devices.autocast():
|
||||||
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
|
||||||
if use_weight:
|
if use_weight:
|
||||||
@ -658,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
|
|
||||||
_loss_step += loss.item()
|
_loss_step += loss.item()
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
|
||||||
# go back until we reach gradient accumulation steps
|
# go back until we reach gradient accumulation steps
|
||||||
if (j + 1) % gradient_step != 0:
|
if (j + 1) % gradient_step != 0:
|
||||||
continue
|
continue
|
||||||
loss_logging.append(_loss_step)
|
loss_logging.append(_loss_step)
|
||||||
if clip_grad:
|
if clip_grad:
|
||||||
clip_grad(weights, clip_grad_sched.learn_rate)
|
clip_grad(weights, clip_grad_sched.learn_rate)
|
||||||
|
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
hypernetwork.step += 1
|
hypernetwork.step += 1
|
||||||
@ -675,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
|
|||||||
_loss_step = 0
|
_loss_step = 0
|
||||||
|
|
||||||
steps_done = hypernetwork.step + 1
|
steps_done = hypernetwork.step + 1
|
||||||
|
|
||||||
epoch_num = hypernetwork.step // steps_per_epoch
|
epoch_num = hypernetwork.step // steps_per_epoch
|
||||||
epoch_step = hypernetwork.step % steps_per_epoch
|
epoch_step = hypernetwork.step % steps_per_epoch
|
||||||
|
|
||||||
|
@ -1,19 +1,17 @@
|
|||||||
import html
|
import html
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.hypernetworks.hypernetwork
|
import modules.hypernetworks.hypernetwork
|
||||||
from modules import devices, sd_hijack, shared
|
from modules import devices, sd_hijack, shared
|
||||||
|
|
||||||
not_available = ["hardswish", "multiheadattention"]
|
not_available = ["hardswish", "multiheadattention"]
|
||||||
keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
|
keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available]
|
||||||
|
|
||||||
|
|
||||||
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
|
||||||
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure)
|
||||||
|
|
||||||
return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", ""
|
return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", ""
|
||||||
|
|
||||||
|
|
||||||
def train_hypernetwork(*args):
|
def train_hypernetwork(*args):
|
||||||
|
@ -13,17 +13,24 @@ import numpy as np
|
|||||||
import piexif
|
import piexif
|
||||||
import piexif.helper
|
import piexif.helper
|
||||||
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
|
||||||
from fonts.ttf import Roboto
|
|
||||||
import string
|
import string
|
||||||
import json
|
import json
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
from modules import sd_samplers, shared, script_callbacks, errors
|
from modules import sd_samplers, shared, script_callbacks, errors
|
||||||
from modules.shared import opts, cmd_opts
|
from modules.paths_internal import roboto_ttf_file
|
||||||
|
from modules.shared import opts
|
||||||
|
|
||||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_font(fontsize: int):
|
||||||
|
try:
|
||||||
|
return ImageFont.truetype(opts.font or roboto_ttf_file, fontsize)
|
||||||
|
except Exception:
|
||||||
|
return ImageFont.truetype(roboto_ttf_file, fontsize)
|
||||||
|
|
||||||
|
|
||||||
def image_grid(imgs, batch_size=1, rows=None):
|
def image_grid(imgs, batch_size=1, rows=None):
|
||||||
if rows is None:
|
if rows is None:
|
||||||
if opts.n_rows > 0:
|
if opts.n_rows > 0:
|
||||||
@ -142,14 +149,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, margin=0):
|
|||||||
lines.append(word)
|
lines.append(word)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
def get_font(fontsize):
|
|
||||||
try:
|
|
||||||
return ImageFont.truetype(opts.font or Roboto, fontsize)
|
|
||||||
except Exception:
|
|
||||||
return ImageFont.truetype(Roboto, fontsize)
|
|
||||||
|
|
||||||
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
def draw_texts(drawing, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
|
||||||
for i, line in enumerate(lines):
|
for line in lines:
|
||||||
fnt = initial_fnt
|
fnt = initial_fnt
|
||||||
fontsize = initial_fontsize
|
fontsize = initial_fontsize
|
||||||
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
while drawing.multiline_textsize(line.text, font=fnt)[0] > line.allowed_width and fontsize > 0:
|
||||||
@ -318,6 +319,7 @@ re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
|
|||||||
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
|
||||||
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
|
||||||
max_filename_part_length = 128
|
max_filename_part_length = 128
|
||||||
|
NOTHING_AND_SKIP_PREVIOUS_TEXT = object()
|
||||||
|
|
||||||
|
|
||||||
def sanitize_filename_part(text, replace_spaces=True):
|
def sanitize_filename_part(text, replace_spaces=True):
|
||||||
@ -352,6 +354,11 @@ class FilenameGenerator:
|
|||||||
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
'prompt_no_styles': lambda self: self.prompt_no_style(),
|
||||||
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
|
||||||
'prompt_words': lambda self: self.prompt_words(),
|
'prompt_words': lambda self: self.prompt_words(),
|
||||||
|
'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1,
|
||||||
|
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
|
||||||
|
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
|
||||||
|
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
|
||||||
|
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
|
||||||
}
|
}
|
||||||
default_time_format = '%Y%m%d%H%M%S'
|
default_time_format = '%Y%m%d%H%M%S'
|
||||||
|
|
||||||
@ -361,6 +368,22 @@ class FilenameGenerator:
|
|||||||
self.prompt = prompt
|
self.prompt = prompt
|
||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
|
def hasprompt(self, *args):
|
||||||
|
lower = self.prompt.lower()
|
||||||
|
if self.p is None or self.prompt is None:
|
||||||
|
return None
|
||||||
|
outres = ""
|
||||||
|
for arg in args:
|
||||||
|
if arg != "":
|
||||||
|
division = arg.split("|")
|
||||||
|
expected = division[0].lower()
|
||||||
|
default = division[1] if len(division) > 1 else ""
|
||||||
|
if lower.find(expected) >= 0:
|
||||||
|
outres = f'{outres}{expected}'
|
||||||
|
else:
|
||||||
|
outres = outres if default == "" else f'{outres}{default}'
|
||||||
|
return sanitize_filename_part(outres)
|
||||||
|
|
||||||
def prompt_no_style(self):
|
def prompt_no_style(self):
|
||||||
if self.p is None or self.prompt is None:
|
if self.p is None or self.prompt is None:
|
||||||
return None
|
return None
|
||||||
@ -387,13 +410,13 @@ class FilenameGenerator:
|
|||||||
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
|
||||||
try:
|
try:
|
||||||
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
|
||||||
except pytz.exceptions.UnknownTimeZoneError as _:
|
except pytz.exceptions.UnknownTimeZoneError:
|
||||||
time_zone = None
|
time_zone = None
|
||||||
|
|
||||||
time_zone_time = time_datetime.astimezone(time_zone)
|
time_zone_time = time_datetime.astimezone(time_zone)
|
||||||
try:
|
try:
|
||||||
formatted_time = time_zone_time.strftime(time_format)
|
formatted_time = time_zone_time.strftime(time_format)
|
||||||
except (ValueError, TypeError) as _:
|
except (ValueError, TypeError):
|
||||||
formatted_time = time_zone_time.strftime(self.default_time_format)
|
formatted_time = time_zone_time.strftime(self.default_time_format)
|
||||||
|
|
||||||
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
return sanitize_filename_part(formatted_time, replace_spaces=False)
|
||||||
@ -403,9 +426,9 @@ class FilenameGenerator:
|
|||||||
|
|
||||||
for m in re_pattern.finditer(x):
|
for m in re_pattern.finditer(x):
|
||||||
text, pattern = m.groups()
|
text, pattern = m.groups()
|
||||||
res += text
|
|
||||||
|
|
||||||
if pattern is None:
|
if pattern is None:
|
||||||
|
res += text
|
||||||
continue
|
continue
|
||||||
|
|
||||||
pattern_args = []
|
pattern_args = []
|
||||||
@ -426,11 +449,13 @@ class FilenameGenerator:
|
|||||||
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
print(f"Error adding [{pattern}] to filename", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
if replacement is not None:
|
if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT:
|
||||||
res += str(replacement)
|
continue
|
||||||
|
elif replacement is not None:
|
||||||
|
res += text + str(replacement)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
res += f'[{pattern}]'
|
res += f'{text}[{pattern}]'
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -443,20 +468,57 @@ def get_next_sequence_number(path, basename):
|
|||||||
"""
|
"""
|
||||||
result = -1
|
result = -1
|
||||||
if basename != '':
|
if basename != '':
|
||||||
basename = basename + "-"
|
basename = f"{basename}-"
|
||||||
|
|
||||||
prefix_length = len(basename)
|
prefix_length = len(basename)
|
||||||
for p in os.listdir(path):
|
for p in os.listdir(path):
|
||||||
if p.startswith(basename):
|
if p.startswith(basename):
|
||||||
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
parts = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
|
||||||
try:
|
try:
|
||||||
result = max(int(l[0]), result)
|
result = max(int(parts[0]), result)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return result + 1
|
return result + 1
|
||||||
|
|
||||||
|
|
||||||
|
def save_image_with_geninfo(image, geninfo, filename, extension=None, existing_pnginfo=None):
|
||||||
|
if extension is None:
|
||||||
|
extension = os.path.splitext(filename)[1]
|
||||||
|
|
||||||
|
image_format = Image.registered_extensions()[extension]
|
||||||
|
|
||||||
|
existing_pnginfo = existing_pnginfo or {}
|
||||||
|
if opts.enable_pnginfo:
|
||||||
|
existing_pnginfo['parameters'] = geninfo
|
||||||
|
|
||||||
|
if extension.lower() == '.png':
|
||||||
|
pnginfo_data = PngImagePlugin.PngInfo()
|
||||||
|
for k, v in (existing_pnginfo or {}).items():
|
||||||
|
pnginfo_data.add_text(k, str(v))
|
||||||
|
|
||||||
|
image.save(filename, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
||||||
|
|
||||||
|
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
||||||
|
if image.mode == 'RGBA':
|
||||||
|
image = image.convert("RGB")
|
||||||
|
elif image.mode == 'I;16':
|
||||||
|
image = image.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
||||||
|
|
||||||
|
image.save(filename, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
||||||
|
|
||||||
|
if opts.enable_pnginfo and geninfo is not None:
|
||||||
|
exif_bytes = piexif.dump({
|
||||||
|
"Exif": {
|
||||||
|
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(geninfo or "", encoding="unicode")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
piexif.insert(exif_bytes, filename)
|
||||||
|
else:
|
||||||
|
image.save(filename, format=image_format, quality=opts.jpeg_quality)
|
||||||
|
|
||||||
|
|
||||||
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
|
||||||
"""Save an image.
|
"""Save an image.
|
||||||
|
|
||||||
@ -512,7 +574,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
add_number = opts.save_images_add_number or file_decoration == ''
|
add_number = opts.save_images_add_number or file_decoration == ''
|
||||||
|
|
||||||
if file_decoration != "" and add_number:
|
if file_decoration != "" and add_number:
|
||||||
file_decoration = "-" + file_decoration
|
file_decoration = f"-{file_decoration}"
|
||||||
|
|
||||||
file_decoration = namegen.apply(file_decoration) + suffix
|
file_decoration = namegen.apply(file_decoration) + suffix
|
||||||
|
|
||||||
@ -541,38 +603,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
info = params.pnginfo.get(pnginfo_section_name, None)
|
info = params.pnginfo.get(pnginfo_section_name, None)
|
||||||
|
|
||||||
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
def _atomically_save_image(image_to_save, filename_without_extension, extension):
|
||||||
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
"""
|
||||||
temp_file_path = filename_without_extension + ".tmp"
|
save image with .tmp extension to avoid race condition when another process detects new image in the directory
|
||||||
image_format = Image.registered_extensions()[extension]
|
"""
|
||||||
|
temp_file_path = f"{filename_without_extension}.tmp"
|
||||||
|
|
||||||
if extension.lower() == '.png':
|
save_image_with_geninfo(image_to_save, info, temp_file_path, extension, params.pnginfo)
|
||||||
pnginfo_data = PngImagePlugin.PngInfo()
|
|
||||||
if opts.enable_pnginfo:
|
|
||||||
for k, v in params.pnginfo.items():
|
|
||||||
pnginfo_data.add_text(k, str(v))
|
|
||||||
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
|
|
||||||
|
|
||||||
elif extension.lower() in (".jpg", ".jpeg", ".webp"):
|
|
||||||
if image_to_save.mode == 'RGBA':
|
|
||||||
image_to_save = image_to_save.convert("RGB")
|
|
||||||
elif image_to_save.mode == 'I;16':
|
|
||||||
image_to_save = image_to_save.point(lambda p: p * 0.0038910505836576).convert("RGB" if extension.lower() == ".webp" else "L")
|
|
||||||
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, lossless=opts.webp_lossless)
|
|
||||||
|
|
||||||
if opts.enable_pnginfo and info is not None:
|
|
||||||
exif_bytes = piexif.dump({
|
|
||||||
"Exif": {
|
|
||||||
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
piexif.insert(exif_bytes, temp_file_path)
|
|
||||||
else:
|
|
||||||
image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality)
|
|
||||||
|
|
||||||
# atomically rename the file with correct extension
|
|
||||||
os.replace(temp_file_path, filename_without_extension + extension)
|
os.replace(temp_file_path, filename_without_extension + extension)
|
||||||
|
|
||||||
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
fullfn_without_extension, extension = os.path.splitext(params.filename)
|
||||||
@ -602,7 +639,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
|
|||||||
if opts.save_txt and info is not None:
|
if opts.save_txt and info is not None:
|
||||||
txt_fullfn = f"{fullfn_without_extension}.txt"
|
txt_fullfn = f"{fullfn_without_extension}.txt"
|
||||||
with open(txt_fullfn, "w", encoding="utf8") as file:
|
with open(txt_fullfn, "w", encoding="utf8") as file:
|
||||||
file.write(info + "\n")
|
file.write(f"{info}\n")
|
||||||
else:
|
else:
|
||||||
txt_fullfn = None
|
txt_fullfn = None
|
||||||
|
|
||||||
|
@ -1,19 +1,15 @@
|
|||||||
import math
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops
|
from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
|
||||||
|
|
||||||
from modules import devices, sd_samplers
|
from modules import sd_samplers
|
||||||
from modules.generation_parameters_copypaste import create_override_settings_dict
|
from modules.generation_parameters_copypaste import create_override_settings_dict
|
||||||
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.shared import opts, state
|
from modules.shared import opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
import modules.processing as processing
|
import modules.processing as processing
|
||||||
from modules.ui import plaintext_to_html
|
from modules.ui import plaintext_to_html
|
||||||
import modules.images as images
|
|
||||||
import modules.scripts
|
import modules.scripts
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +42,11 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
if state.interrupted:
|
if state.interrupted:
|
||||||
break
|
break
|
||||||
|
|
||||||
img = Image.open(image)
|
try:
|
||||||
|
img = Image.open(image)
|
||||||
|
except UnidentifiedImageError as e:
|
||||||
|
print(e)
|
||||||
|
continue
|
||||||
# Use the EXIF orientation of photos taken by smartphones.
|
# Use the EXIF orientation of photos taken by smartphones.
|
||||||
img = ImageOps.exif_transpose(img)
|
img = ImageOps.exif_transpose(img)
|
||||||
p.init_images = [img] * p.batch_size
|
p.init_images = [img] * p.batch_size
|
||||||
@ -55,7 +55,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
# try to find corresponding mask for an image using simple filename matching
|
# try to find corresponding mask for an image using simple filename matching
|
||||||
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image))
|
||||||
# if not found use first one ("same mask for all images" use-case)
|
# if not found use first one ("same mask for all images" use-case)
|
||||||
if not mask_image_path in inpaint_masks:
|
if mask_image_path not in inpaint_masks:
|
||||||
mask_image_path = inpaint_masks[0]
|
mask_image_path = inpaint_masks[0]
|
||||||
mask_image = Image.open(mask_image_path)
|
mask_image = Image.open(mask_image_path)
|
||||||
p.image_mask = mask_image
|
p.image_mask = mask_image
|
||||||
@ -78,7 +78,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
|
|||||||
processed_image.save(os.path.join(output_dir, filename))
|
processed_image.save(os.path.join(output_dir, filename))
|
||||||
|
|
||||||
|
|
||||||
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args):
|
||||||
override_settings = create_override_settings_dict(override_settings_texts)
|
override_settings = create_override_settings_dict(override_settings_texts)
|
||||||
|
|
||||||
is_batch = mode == 5
|
is_batch = mode == 5
|
||||||
@ -114,6 +114,12 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
if image is not None:
|
if image is not None:
|
||||||
image = ImageOps.exif_transpose(image)
|
image = ImageOps.exif_transpose(image)
|
||||||
|
|
||||||
|
if selected_scale_tab == 1:
|
||||||
|
assert image, "Can't scale by because no image is selected"
|
||||||
|
|
||||||
|
width = int(image.width * scale_by)
|
||||||
|
height = int(image.height * scale_by)
|
||||||
|
|
||||||
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
|
|
||||||
p = StableDiffusionProcessingImg2Img(
|
p = StableDiffusionProcessingImg2Img(
|
||||||
@ -151,7 +157,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
|
|||||||
override_settings=override_settings,
|
override_settings=override_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
p.scripts = modules.scripts.scripts_txt2img
|
p.scripts = modules.scripts.scripts_img2img
|
||||||
p.script_args = args
|
p.script_args = args
|
||||||
|
|
||||||
if shared.cmd_opts.enable_console_prompts:
|
if shared.cmd_opts.enable_console_prompts:
|
||||||
|
@ -11,7 +11,6 @@ import torch.hub
|
|||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
import modules.shared as shared
|
|
||||||
from modules import devices, paths, shared, lowvram, modelloader, errors
|
from modules import devices, paths, shared, lowvram, modelloader, errors
|
||||||
|
|
||||||
blip_image_eval_size = 384
|
blip_image_eval_size = 384
|
||||||
@ -28,11 +27,11 @@ def category_types():
|
|||||||
def download_default_clip_interrogate_categories(content_dir):
|
def download_default_clip_interrogate_categories(content_dir):
|
||||||
print("Downloading CLIP categories...")
|
print("Downloading CLIP categories...")
|
||||||
|
|
||||||
tmpdir = content_dir + "_tmp"
|
tmpdir = f"{content_dir}_tmp"
|
||||||
category_types = ["artists", "flavors", "mediums", "movements"]
|
category_types = ["artists", "flavors", "mediums", "movements"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
os.makedirs(tmpdir)
|
os.makedirs(tmpdir, exist_ok=True)
|
||||||
for category_type in category_types:
|
for category_type in category_types:
|
||||||
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
|
||||||
os.rename(tmpdir, content_dir)
|
os.rename(tmpdir, content_dir)
|
||||||
@ -41,7 +40,7 @@ def download_default_clip_interrogate_categories(content_dir):
|
|||||||
errors.display(e, "downloading default CLIP interrogate categories")
|
errors.display(e, "downloading default CLIP interrogate categories")
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(tmpdir):
|
if os.path.exists(tmpdir):
|
||||||
os.remove(tmpdir)
|
os.removedirs(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
class InterrogateModels:
|
class InterrogateModels:
|
||||||
@ -160,7 +159,7 @@ class InterrogateModels:
|
|||||||
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
|
||||||
|
|
||||||
top_count = min(top_count, len(text_array))
|
top_count = min(top_count, len(text_array))
|
||||||
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
|
text_tokens = clip.tokenize(list(text_array), truncate=True).to(devices.device_interrogate)
|
||||||
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
|
||||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
@ -208,13 +207,13 @@ class InterrogateModels:
|
|||||||
|
|
||||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||||
|
|
||||||
for name, topn, items in self.categories():
|
for cat in self.categories():
|
||||||
matches = self.rank(image_features, items, top_count=topn)
|
matches = self.rank(image_features, cat.items, top_count=cat.topn)
|
||||||
for match, score in matches:
|
for match, score in matches:
|
||||||
if shared.opts.interrogate_return_ranks:
|
if shared.opts.interrogate_return_ranks:
|
||||||
res += f", ({match}:{score/100:.3f})"
|
res += f", ({match}:{score/100:.3f})"
|
||||||
else:
|
else:
|
||||||
res += ", " + match
|
res += f", {match}"
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print("Error interrogating", file=sys.stderr)
|
print("Error interrogating", file=sys.stderr)
|
||||||
|
@ -23,7 +23,7 @@ def list_localizations(dirname):
|
|||||||
localizations[fn] = file.path
|
localizations[fn] = file.path
|
||||||
|
|
||||||
|
|
||||||
def localization_js(current_localization_name):
|
def localization_js(current_localization_name: str) -> str:
|
||||||
fn = localizations.get(current_localization_name, None)
|
fn = localizations.get(current_localization_name, None)
|
||||||
data = {}
|
data = {}
|
||||||
if fn is not None:
|
if fn is not None:
|
||||||
@ -34,4 +34,4 @@ def localization_js(current_localization_name):
|
|||||||
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
print(f"Error loading localization from {fn}:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
return f"var localization = {json.dumps(data)}\n"
|
return f"window.localization = {json.dumps(data)}"
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
import platform
|
import platform
|
||||||
from modules import paths
|
|
||||||
from modules.sd_hijack_utils import CondFunc
|
from modules.sd_hijack_utils import CondFunc
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
@ -43,7 +42,7 @@ if has_mps:
|
|||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
|
||||||
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
|
||||||
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
|
||||||
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
|
||||||
@ -54,6 +53,11 @@ if has_mps:
|
|||||||
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
|
||||||
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
|
||||||
if version.parse(torch.__version__) == version.parse("2.0"):
|
|
||||||
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
|
||||||
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
|
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
|
||||||
|
|
||||||
|
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
|
||||||
|
if platform.processor() == 'i386':
|
||||||
|
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
|
||||||
|
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
|
||||||
|
@ -4,7 +4,7 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
def get_crop_region(mask, pad=0):
|
def get_crop_region(mask, pad=0):
|
||||||
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
"""finds a rectangular region that contains all masked ares in an image. Returns (x1, y1, x2, y2) coordinates of the rectangle.
|
||||||
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
For example, if a user has painted the top-right part of a 512x512 image", the result may be (256, 0, 512, 256)"""
|
||||||
|
|
||||||
h, w = mask.shape
|
h, w = mask.shape
|
||||||
|
|
||||||
crop_left = 0
|
crop_left = 0
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import importlib
|
import importlib
|
||||||
@ -22,9 +21,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
"""
|
"""
|
||||||
output = []
|
output = []
|
||||||
|
|
||||||
if ext_filter is None:
|
|
||||||
ext_filter = []
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
places = []
|
places = []
|
||||||
|
|
||||||
@ -39,22 +35,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
|
|||||||
places.append(model_path)
|
places.append(model_path)
|
||||||
|
|
||||||
for place in places:
|
for place in places:
|
||||||
if os.path.exists(place):
|
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
|
||||||
for file in glob.iglob(place + '**/**', recursive=True):
|
if os.path.islink(full_path) and not os.path.exists(full_path):
|
||||||
full_path = file
|
print(f"Skipping broken symlink: {full_path}")
|
||||||
if os.path.isdir(full_path):
|
continue
|
||||||
continue
|
if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
|
||||||
if os.path.islink(full_path) and not os.path.exists(full_path):
|
continue
|
||||||
print(f"Skipping broken symlink: {full_path}")
|
if full_path not in output:
|
||||||
continue
|
output.append(full_path)
|
||||||
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
|
|
||||||
continue
|
|
||||||
if len(ext_filter) != 0:
|
|
||||||
model_name, extension = os.path.splitext(file)
|
|
||||||
if extension not in ext_filter:
|
|
||||||
continue
|
|
||||||
if file not in output:
|
|
||||||
output.append(full_path)
|
|
||||||
|
|
||||||
if model_url is not None and len(output) == 0:
|
if model_url is not None and len(output) == 0:
|
||||||
if download_name is not None:
|
if download_name is not None:
|
||||||
@ -119,32 +107,15 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
|
|||||||
print(f"Moving {file} from {src_path} to {dest_path}.")
|
print(f"Moving {file} from {src_path} to {dest_path}.")
|
||||||
try:
|
try:
|
||||||
shutil.move(fullpath, dest_path)
|
shutil.move(fullpath, dest_path)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if len(os.listdir(src_path)) == 0:
|
if len(os.listdir(src_path)) == 0:
|
||||||
print(f"Removing empty folder: {src_path}")
|
print(f"Removing empty folder: {src_path}")
|
||||||
shutil.rmtree(src_path, True)
|
shutil.rmtree(src_path, True)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
builtin_upscaler_classes = []
|
|
||||||
forbidden_upscaler_classes = set()
|
|
||||||
|
|
||||||
|
|
||||||
def list_builtin_upscalers():
|
|
||||||
load_upscalers()
|
|
||||||
|
|
||||||
builtin_upscaler_classes.clear()
|
|
||||||
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
|
|
||||||
|
|
||||||
|
|
||||||
def forbid_loaded_nonbuiltin_upscalers():
|
|
||||||
for cls in Upscaler.__subclasses__():
|
|
||||||
if cls not in builtin_upscaler_classes:
|
|
||||||
forbidden_upscaler_classes.add(cls)
|
|
||||||
|
|
||||||
|
|
||||||
def load_upscalers():
|
def load_upscalers():
|
||||||
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
|
||||||
# so we'll try to import any _model.py files before looking in __subclasses__
|
# so we'll try to import any _model.py files before looking in __subclasses__
|
||||||
@ -155,15 +126,22 @@ def load_upscalers():
|
|||||||
full_model = f"modules.{model_name}_model"
|
full_model = f"modules.{model_name}_model"
|
||||||
try:
|
try:
|
||||||
importlib.import_module(full_model)
|
importlib.import_module(full_model)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
datas = []
|
datas = []
|
||||||
commandline_options = vars(shared.cmd_opts)
|
commandline_options = vars(shared.cmd_opts)
|
||||||
for cls in Upscaler.__subclasses__():
|
|
||||||
if cls in forbidden_upscaler_classes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
# some of upscaler classes will not go away after reloading their modules, and we'll end
|
||||||
|
# up with two copies of those classes. The newest copy will always be the last in the list,
|
||||||
|
# so we go from end to beginning and ignore duplicates
|
||||||
|
used_classes = {}
|
||||||
|
for cls in reversed(Upscaler.__subclasses__()):
|
||||||
|
classname = str(cls)
|
||||||
|
if classname not in used_classes:
|
||||||
|
used_classes[classname] = cls
|
||||||
|
|
||||||
|
for cls in reversed(used_classes.values()):
|
||||||
name = cls.__name__
|
name = cls.__name__
|
||||||
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
|
||||||
scaler = cls(commandline_options.get(cmd_name, None))
|
scaler = cls(commandline_options.get(cmd_name, None))
|
||||||
|
@ -52,7 +52,7 @@ class DDPM(pl.LightningModule):
|
|||||||
beta_schedule="linear",
|
beta_schedule="linear",
|
||||||
loss_type="l2",
|
loss_type="l2",
|
||||||
ckpt_path=None,
|
ckpt_path=None,
|
||||||
ignore_keys=[],
|
ignore_keys=None,
|
||||||
load_only_unet=False,
|
load_only_unet=False,
|
||||||
monitor="val/loss",
|
monitor="val/loss",
|
||||||
use_ema=True,
|
use_ema=True,
|
||||||
@ -107,7 +107,7 @@ class DDPM(pl.LightningModule):
|
|||||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
if ckpt_path is not None:
|
if ckpt_path is not None:
|
||||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet)
|
||||||
|
|
||||||
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
# If initialing from EMA-only checkpoint, create EMA model after loading.
|
||||||
if self.use_ema and not load_ema:
|
if self.use_ema and not load_ema:
|
||||||
@ -194,7 +194,9 @@ class DDPM(pl.LightningModule):
|
|||||||
if context is not None:
|
if context is not None:
|
||||||
print(f"{context}: Restored training weights")
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
|
def init_from_ckpt(self, path, ignore_keys=None, only_model=False):
|
||||||
|
ignore_keys = ignore_keys or []
|
||||||
|
|
||||||
sd = torch.load(path, map_location="cpu")
|
sd = torch.load(path, map_location="cpu")
|
||||||
if "state_dict" in list(sd.keys()):
|
if "state_dict" in list(sd.keys()):
|
||||||
sd = sd["state_dict"]
|
sd = sd["state_dict"]
|
||||||
@ -223,7 +225,7 @@ class DDPM(pl.LightningModule):
|
|||||||
for k in keys:
|
for k in keys:
|
||||||
for ik in ignore_keys:
|
for ik in ignore_keys:
|
||||||
if k.startswith(ik):
|
if k.startswith(ik):
|
||||||
print("Deleting key {} from state_dict.".format(k))
|
print(f"Deleting key {k} from state_dict.")
|
||||||
del sd[k]
|
del sd[k]
|
||||||
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
|
||||||
sd, strict=False)
|
sd, strict=False)
|
||||||
@ -386,7 +388,7 @@ class DDPM(pl.LightningModule):
|
|||||||
_, loss_dict_no_ema = self.shared_step(batch)
|
_, loss_dict_no_ema = self.shared_step(batch)
|
||||||
with self.ema_scope():
|
with self.ema_scope():
|
||||||
_, loss_dict_ema = self.shared_step(batch)
|
_, loss_dict_ema = self.shared_step(batch)
|
||||||
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
|
||||||
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||||
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
||||||
|
|
||||||
@ -403,7 +405,7 @@ class DDPM(pl.LightningModule):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
|
||||||
log = dict()
|
log = {}
|
||||||
x = self.get_input(batch, self.first_stage_key)
|
x = self.get_input(batch, self.first_stage_key)
|
||||||
N = min(x.shape[0], N)
|
N = min(x.shape[0], N)
|
||||||
n_row = min(x.shape[0], n_row)
|
n_row = min(x.shape[0], n_row)
|
||||||
@ -411,7 +413,7 @@ class DDPM(pl.LightningModule):
|
|||||||
log["inputs"] = x
|
log["inputs"] = x
|
||||||
|
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
x_start = x[:n_row]
|
x_start = x[:n_row]
|
||||||
|
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
@ -473,13 +475,13 @@ class LatentDiffusion(DDPM):
|
|||||||
conditioning_key = None
|
conditioning_key = None
|
||||||
ckpt_path = kwargs.pop("ckpt_path", None)
|
ckpt_path = kwargs.pop("ckpt_path", None)
|
||||||
ignore_keys = kwargs.pop("ignore_keys", [])
|
ignore_keys = kwargs.pop("ignore_keys", [])
|
||||||
super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs)
|
super().__init__(*args, conditioning_key=conditioning_key, load_ema=load_ema, **kwargs)
|
||||||
self.concat_mode = concat_mode
|
self.concat_mode = concat_mode
|
||||||
self.cond_stage_trainable = cond_stage_trainable
|
self.cond_stage_trainable = cond_stage_trainable
|
||||||
self.cond_stage_key = cond_stage_key
|
self.cond_stage_key = cond_stage_key
|
||||||
try:
|
try:
|
||||||
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
|
||||||
except:
|
except Exception:
|
||||||
self.num_downs = 0
|
self.num_downs = 0
|
||||||
if not scale_by_std:
|
if not scale_by_std:
|
||||||
self.scale_factor = scale_factor
|
self.scale_factor = scale_factor
|
||||||
@ -891,16 +893,6 @@ class LatentDiffusion(DDPM):
|
|||||||
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
|
||||||
return self.p_losses(x, c, t, *args, **kwargs)
|
return self.p_losses(x, c, t, *args, **kwargs)
|
||||||
|
|
||||||
def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
|
|
||||||
def rescale_bbox(bbox):
|
|
||||||
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
|
|
||||||
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
|
|
||||||
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
|
|
||||||
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
|
|
||||||
return x0, y0, w, h
|
|
||||||
|
|
||||||
return [rescale_bbox(b) for b in bboxes]
|
|
||||||
|
|
||||||
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
def apply_model(self, x_noisy, t, cond, return_ids=False):
|
||||||
|
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
@ -1140,7 +1132,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
|
|
||||||
@ -1171,8 +1163,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(x0_partial)
|
intermediates.append(x0_partial)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@ -1219,8 +1213,10 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if i % log_every_t == 0 or i == timesteps - 1:
|
if i % log_every_t == 0 or i == timesteps - 1:
|
||||||
intermediates.append(img)
|
intermediates.append(img)
|
||||||
if callback: callback(i)
|
if callback:
|
||||||
if img_callback: img_callback(img, i)
|
callback(i)
|
||||||
|
if img_callback:
|
||||||
|
img_callback(img, i)
|
||||||
|
|
||||||
if return_intermediates:
|
if return_intermediates:
|
||||||
return img, intermediates
|
return img, intermediates
|
||||||
@ -1235,7 +1231,7 @@ class LatentDiffusion(DDPM):
|
|||||||
if cond is not None:
|
if cond is not None:
|
||||||
if isinstance(cond, dict):
|
if isinstance(cond, dict):
|
||||||
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
|
||||||
list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
|
[x[:batch_size] for x in cond[key]] for key in cond}
|
||||||
else:
|
else:
|
||||||
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
|
||||||
return self.p_sample_loop(cond,
|
return self.p_sample_loop(cond,
|
||||||
@ -1267,7 +1263,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
use_ddim = False
|
use_ddim = False
|
||||||
|
|
||||||
log = dict()
|
log = {}
|
||||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
|
||||||
return_first_stage_outputs=True,
|
return_first_stage_outputs=True,
|
||||||
force_c_encode=True,
|
force_c_encode=True,
|
||||||
@ -1295,7 +1291,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if plot_diffusion_rows:
|
if plot_diffusion_rows:
|
||||||
# get diffusion row
|
# get diffusion row
|
||||||
diffusion_row = list()
|
diffusion_row = []
|
||||||
z_start = z[:n_row]
|
z_start = z[:n_row]
|
||||||
for t in range(self.num_timesteps):
|
for t in range(self.num_timesteps):
|
||||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||||
@ -1337,7 +1333,7 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
if inpaint:
|
if inpaint:
|
||||||
# make a simple center square
|
# make a simple center square
|
||||||
b, h, w = z.shape[0], z.shape[2], z.shape[3]
|
h, w = z.shape[2], z.shape[3]
|
||||||
mask = torch.ones(N, h, w).to(self.device)
|
mask = torch.ones(N, h, w).to(self.device)
|
||||||
# zeros will be filled in
|
# zeros will be filled in
|
||||||
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
|
||||||
@ -1439,10 +1435,10 @@ class Layout2ImgDiffusion(LatentDiffusion):
|
|||||||
# TODO: move all layout-specific hacks to this class
|
# TODO: move all layout-specific hacks to this class
|
||||||
def __init__(self, cond_stage_key, *args, **kwargs):
|
def __init__(self, cond_stage_key, *args, **kwargs):
|
||||||
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
|
||||||
super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
|
super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs)
|
||||||
|
|
||||||
def log_images(self, batch, N=8, *args, **kwargs):
|
def log_images(self, batch, N=8, *args, **kwargs):
|
||||||
logs = super().log_images(batch=batch, N=N, *args, **kwargs)
|
logs = super().log_images(*args, batch=batch, N=N, **kwargs)
|
||||||
|
|
||||||
key = 'train' if self.training else 'validation'
|
key = 'train' if self.training else 'validation'
|
||||||
dset = self.trainer.datamodule.datasets[key]
|
dset = self.trainer.datamodule.datasets[key]
|
||||||
|
@ -1 +1 @@
|
|||||||
from .sampler import UniPCSampler
|
from .sampler import UniPCSampler # noqa: F401
|
||||||
|
@ -54,7 +54,8 @@ class UniPCSampler(object):
|
|||||||
if conditioning is not None:
|
if conditioning is not None:
|
||||||
if isinstance(conditioning, dict):
|
if isinstance(conditioning, dict):
|
||||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
while isinstance(ctmp, list):
|
||||||
|
ctmp = ctmp[0]
|
||||||
cbs = ctmp.shape[0]
|
cbs = ctmp.shape[0]
|
||||||
if cbs != batch_size:
|
if cbs != batch_size:
|
||||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
import math
|
||||||
from tqdm.auto import trange
|
import tqdm
|
||||||
|
|
||||||
|
|
||||||
class NoiseScheduleVP:
|
class NoiseScheduleVP:
|
||||||
@ -94,7 +93,7 @@ class NoiseScheduleVP:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if schedule not in ['discrete', 'linear', 'cosine']:
|
if schedule not in ['discrete', 'linear', 'cosine']:
|
||||||
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
|
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
|
||||||
|
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
if schedule == 'discrete':
|
if schedule == 'discrete':
|
||||||
@ -179,13 +178,13 @@ def model_wrapper(
|
|||||||
model,
|
model,
|
||||||
noise_schedule,
|
noise_schedule,
|
||||||
model_type="noise",
|
model_type="noise",
|
||||||
model_kwargs={},
|
model_kwargs=None,
|
||||||
guidance_type="uncond",
|
guidance_type="uncond",
|
||||||
#condition=None,
|
#condition=None,
|
||||||
#unconditional_condition=None,
|
#unconditional_condition=None,
|
||||||
guidance_scale=1.,
|
guidance_scale=1.,
|
||||||
classifier_fn=None,
|
classifier_fn=None,
|
||||||
classifier_kwargs={},
|
classifier_kwargs=None,
|
||||||
):
|
):
|
||||||
"""Create a wrapper function for the noise prediction model.
|
"""Create a wrapper function for the noise prediction model.
|
||||||
|
|
||||||
@ -276,6 +275,9 @@ def model_wrapper(
|
|||||||
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
A noise prediction model that accepts the noised data and the continuous time as the inputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_kwargs = model_kwargs or {}
|
||||||
|
classifier_kwargs = classifier_kwargs or {}
|
||||||
|
|
||||||
def get_model_input_time(t_continuous):
|
def get_model_input_time(t_continuous):
|
||||||
"""
|
"""
|
||||||
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
|
||||||
@ -342,7 +344,7 @@ def model_wrapper(
|
|||||||
t_in = torch.cat([t_continuous] * 2)
|
t_in = torch.cat([t_continuous] * 2)
|
||||||
if isinstance(condition, dict):
|
if isinstance(condition, dict):
|
||||||
assert isinstance(unconditional_condition, dict)
|
assert isinstance(unconditional_condition, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in condition:
|
for k in condition:
|
||||||
if isinstance(condition[k], list):
|
if isinstance(condition[k], list):
|
||||||
c_in[k] = [torch.cat([
|
c_in[k] = [torch.cat([
|
||||||
@ -353,7 +355,7 @@ def model_wrapper(
|
|||||||
unconditional_condition[k],
|
unconditional_condition[k],
|
||||||
condition[k]])
|
condition[k]])
|
||||||
elif isinstance(condition, list):
|
elif isinstance(condition, list):
|
||||||
c_in = list()
|
c_in = []
|
||||||
assert isinstance(unconditional_condition, list)
|
assert isinstance(unconditional_condition, list)
|
||||||
for i in range(len(condition)):
|
for i in range(len(condition)):
|
||||||
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
c_in.append(torch.cat([unconditional_condition[i], condition[i]]))
|
||||||
@ -469,7 +471,7 @@ class UniPC:
|
|||||||
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
|
||||||
return t
|
return t
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
|
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
|
||||||
|
|
||||||
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
|
||||||
"""
|
"""
|
||||||
@ -757,40 +759,44 @@ class UniPC:
|
|||||||
vec_t = timesteps[0].expand((x.shape[0]))
|
vec_t = timesteps[0].expand((x.shape[0]))
|
||||||
model_prev_list = [self.model_fn(x, vec_t)]
|
model_prev_list = [self.model_fn(x, vec_t)]
|
||||||
t_prev_list = [vec_t]
|
t_prev_list = [vec_t]
|
||||||
# Init the first `order` values by lower order multistep DPM-Solver.
|
with tqdm.tqdm(total=steps) as pbar:
|
||||||
for init_order in range(1, order):
|
# Init the first `order` values by lower order multistep DPM-Solver.
|
||||||
vec_t = timesteps[init_order].expand(x.shape[0])
|
for init_order in range(1, order):
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
vec_t = timesteps[init_order].expand(x.shape[0])
|
||||||
if model_x is None:
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
|
||||||
model_x = self.model_fn(x, vec_t)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
model_prev_list.append(model_x)
|
|
||||||
t_prev_list.append(vec_t)
|
|
||||||
for step in trange(order, steps + 1):
|
|
||||||
vec_t = timesteps[step].expand(x.shape[0])
|
|
||||||
if lower_order_final:
|
|
||||||
step_order = min(order, steps + 1 - step)
|
|
||||||
else:
|
|
||||||
step_order = order
|
|
||||||
#print('this step order:', step_order)
|
|
||||||
if step == steps:
|
|
||||||
#print('do not run corrector at the last step')
|
|
||||||
use_corrector = False
|
|
||||||
else:
|
|
||||||
use_corrector = True
|
|
||||||
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
|
||||||
if self.after_update is not None:
|
|
||||||
self.after_update(x, model_x)
|
|
||||||
for i in range(order - 1):
|
|
||||||
t_prev_list[i] = t_prev_list[i + 1]
|
|
||||||
model_prev_list[i] = model_prev_list[i + 1]
|
|
||||||
t_prev_list[-1] = vec_t
|
|
||||||
# We do not need to evaluate the final model value.
|
|
||||||
if step < steps:
|
|
||||||
if model_x is None:
|
if model_x is None:
|
||||||
model_x = self.model_fn(x, vec_t)
|
model_x = self.model_fn(x, vec_t)
|
||||||
model_prev_list[-1] = model_x
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
model_prev_list.append(model_x)
|
||||||
|
t_prev_list.append(vec_t)
|
||||||
|
pbar.update()
|
||||||
|
|
||||||
|
for step in range(order, steps + 1):
|
||||||
|
vec_t = timesteps[step].expand(x.shape[0])
|
||||||
|
if lower_order_final:
|
||||||
|
step_order = min(order, steps + 1 - step)
|
||||||
|
else:
|
||||||
|
step_order = order
|
||||||
|
#print('this step order:', step_order)
|
||||||
|
if step == steps:
|
||||||
|
#print('do not run corrector at the last step')
|
||||||
|
use_corrector = False
|
||||||
|
else:
|
||||||
|
use_corrector = True
|
||||||
|
x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
|
||||||
|
if self.after_update is not None:
|
||||||
|
self.after_update(x, model_x)
|
||||||
|
for i in range(order - 1):
|
||||||
|
t_prev_list[i] = t_prev_list[i + 1]
|
||||||
|
model_prev_list[i] = model_prev_list[i + 1]
|
||||||
|
t_prev_list[-1] = vec_t
|
||||||
|
# We do not need to evaluate the final model value.
|
||||||
|
if step < steps:
|
||||||
|
if model_x is None:
|
||||||
|
model_x = self.model_fn(x, vec_t)
|
||||||
|
model_prev_list[-1] = model_x
|
||||||
|
pbar.update()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
if denoise_to_zero:
|
if denoise_to_zero:
|
||||||
|
@ -7,12 +7,24 @@ def connect(token, port, region):
|
|||||||
else:
|
else:
|
||||||
if ':' in token:
|
if ':' in token:
|
||||||
# token = authtoken:username:password
|
# token = authtoken:username:password
|
||||||
account = token.split(':')[1] + ':' + token.split(':')[-1]
|
token, username, password = token.split(':', 2)
|
||||||
token = token.split(':')[0]
|
account = f"{username}:{password}"
|
||||||
|
|
||||||
config = conf.PyngrokConfig(
|
config = conf.PyngrokConfig(
|
||||||
auth_token=token, region=region
|
auth_token=token, region=region
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Guard for existing tunnels
|
||||||
|
existing = ngrok.get_tunnels(pyngrok_config=config)
|
||||||
|
if existing:
|
||||||
|
for established in existing:
|
||||||
|
# Extra configuration in the case that the user is also using ngrok for other tunnels
|
||||||
|
if established.config['addr'][-4:] == str(port):
|
||||||
|
public_url = existing[0].public_url
|
||||||
|
print(f'ngrok has already been connected to localhost:{port}! URL: {public_url}\n'
|
||||||
|
'You can use this link after the launch is complete.')
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if account is None:
|
if account is None:
|
||||||
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
public_url = ngrok.connect(port, pyngrok_config=config, bind_tls=True).public_url
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir
|
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir # noqa: F401
|
||||||
|
|
||||||
import modules.safe
|
import modules.safe # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
# data_path = cmd_opts_pre.data
|
# data_path = cmd_opts_pre.data
|
||||||
@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
|
|||||||
sd_path = os.path.abspath(possible_sd_path)
|
sd_path = os.path.abspath(possible_sd_path)
|
||||||
break
|
break
|
||||||
|
|
||||||
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
|
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
|
||||||
|
|
||||||
path_dirs = [
|
path_dirs = [
|
||||||
(sd_path, 'ldm', 'Stable Diffusion', []),
|
(sd_path, 'ldm', 'Stable Diffusion', []),
|
||||||
|
@ -2,8 +2,14 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import shlex
|
||||||
|
|
||||||
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
|
||||||
|
sys.argv += shlex.split(commandline_args)
|
||||||
|
|
||||||
|
modules_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
script_path = os.path.dirname(modules_path)
|
||||||
|
|
||||||
sd_configs_path = os.path.join(script_path, "configs")
|
sd_configs_path = os.path.join(script_path, "configs")
|
||||||
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
|
||||||
@ -12,7 +18,7 @@ default_sd_model_file = sd_model_file
|
|||||||
|
|
||||||
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
# Parse the --data-dir flag first so we can use it as a base for our other argument default values
|
||||||
parser_pre = argparse.ArgumentParser(add_help=False)
|
parser_pre = argparse.ArgumentParser(add_help=False)
|
||||||
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
|
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", )
|
||||||
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
cmd_opts_pre = parser_pre.parse_known_args()[0]
|
||||||
|
|
||||||
data_path = cmd_opts_pre.data_dir
|
data_path = cmd_opts_pre.data_dir
|
||||||
@ -20,3 +26,6 @@ data_path = cmd_opts_pre.data_dir
|
|||||||
models_path = os.path.join(data_path, "models")
|
models_path = os.path.join(data_path, "models")
|
||||||
extensions_dir = os.path.join(data_path, "extensions")
|
extensions_dir = os.path.join(data_path, "extensions")
|
||||||
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
|
||||||
|
config_states_dir = os.path.join(script_path, "config_states")
|
||||||
|
|
||||||
|
roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf')
|
||||||
|
@ -18,9 +18,14 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
|
|||||||
|
|
||||||
if extras_mode == 1:
|
if extras_mode == 1:
|
||||||
for img in image_folder:
|
for img in image_folder:
|
||||||
image = Image.open(img)
|
if isinstance(img, Image.Image):
|
||||||
|
image = img
|
||||||
|
fn = ''
|
||||||
|
else:
|
||||||
|
image = Image.open(os.path.abspath(img.name))
|
||||||
|
fn = os.path.splitext(img.orig_name)[0]
|
||||||
image_data.append(image)
|
image_data.append(image)
|
||||||
image_names.append(os.path.splitext(img.orig_name)[0])
|
image_names.append(fn)
|
||||||
elif extras_mode == 2:
|
elif extras_mode == 2:
|
||||||
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
|
||||||
assert input_dir, 'input directory not selected'
|
assert input_dir, 'input directory not selected'
|
||||||
|
@ -2,7 +2,7 @@ import json
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import hashlib
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -10,10 +10,10 @@ from PIL import Image, ImageFilter, ImageOps
|
|||||||
import random
|
import random
|
||||||
import cv2
|
import cv2
|
||||||
from skimage import exposure
|
from skimage import exposure
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
|
||||||
from modules.sd_hijack import model_hijack
|
from modules.sd_hijack import model_hijack
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
|
|||||||
from einops import repeat, rearrange
|
from einops import repeat, rearrange
|
||||||
from blendmodes.blend import blendLayers, BlendType
|
from blendmodes.blend import blendLayers, BlendType
|
||||||
|
|
||||||
|
|
||||||
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
# some of those options should not be changed at all because they would break the model, so I removed them from options.
|
||||||
opt_C = 4
|
opt_C = 4
|
||||||
opt_f = 8
|
opt_f = 8
|
||||||
@ -105,7 +106,7 @@ class StableDiffusionProcessing:
|
|||||||
"""
|
"""
|
||||||
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
|
||||||
"""
|
"""
|
||||||
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
|
||||||
if sampler_index is not None:
|
if sampler_index is not None:
|
||||||
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
|
||||||
|
|
||||||
@ -140,6 +141,7 @@ class StableDiffusionProcessing:
|
|||||||
self.denoising_strength: float = denoising_strength
|
self.denoising_strength: float = denoising_strength
|
||||||
self.sampler_noise_scheduler_override = None
|
self.sampler_noise_scheduler_override = None
|
||||||
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
self.ddim_discretize = ddim_discretize or opts.ddim_discretize
|
||||||
|
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
|
||||||
self.s_churn = s_churn or opts.s_churn
|
self.s_churn = s_churn or opts.s_churn
|
||||||
self.s_tmin = s_tmin or opts.s_tmin
|
self.s_tmin = s_tmin or opts.s_tmin
|
||||||
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
|
||||||
@ -148,6 +150,8 @@ class StableDiffusionProcessing:
|
|||||||
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
self.override_settings_restore_afterwards = override_settings_restore_afterwards
|
||||||
self.is_using_inpainting_conditioning = False
|
self.is_using_inpainting_conditioning = False
|
||||||
self.disable_extra_networks = False
|
self.disable_extra_networks = False
|
||||||
|
self.token_merging_ratio = 0
|
||||||
|
self.token_merging_ratio_hr = 0
|
||||||
|
|
||||||
if not seed_enable_extras:
|
if not seed_enable_extras:
|
||||||
self.subseed = -1
|
self.subseed = -1
|
||||||
@ -162,6 +166,9 @@ class StableDiffusionProcessing:
|
|||||||
self.all_seeds = None
|
self.all_seeds = None
|
||||||
self.all_subseeds = None
|
self.all_subseeds = None
|
||||||
self.iteration = 0
|
self.iteration = 0
|
||||||
|
self.is_hr_pass = False
|
||||||
|
self.sampler = None
|
||||||
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sd_model(self):
|
def sd_model(self):
|
||||||
@ -270,6 +277,12 @@ class StableDiffusionProcessing:
|
|||||||
def close(self):
|
def close(self):
|
||||||
self.sampler = None
|
self.sampler = None
|
||||||
|
|
||||||
|
def get_token_merging_ratio(self, for_hr=False):
|
||||||
|
if for_hr:
|
||||||
|
return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
|
||||||
|
|
||||||
|
return self.token_merging_ratio or opts.token_merging_ratio
|
||||||
|
|
||||||
|
|
||||||
class Processed:
|
class Processed:
|
||||||
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
|
||||||
@ -299,6 +312,8 @@ class Processed:
|
|||||||
self.styles = p.styles
|
self.styles = p.styles
|
||||||
self.job_timestamp = state.job_timestamp
|
self.job_timestamp = state.job_timestamp
|
||||||
self.clip_skip = opts.CLIP_stop_at_last_layers
|
self.clip_skip = opts.CLIP_stop_at_last_layers
|
||||||
|
self.token_merging_ratio = p.token_merging_ratio
|
||||||
|
self.token_merging_ratio_hr = p.token_merging_ratio_hr
|
||||||
|
|
||||||
self.eta = p.eta
|
self.eta = p.eta
|
||||||
self.ddim_discretize = p.ddim_discretize
|
self.ddim_discretize = p.ddim_discretize
|
||||||
@ -306,6 +321,7 @@ class Processed:
|
|||||||
self.s_tmin = p.s_tmin
|
self.s_tmin = p.s_tmin
|
||||||
self.s_tmax = p.s_tmax
|
self.s_tmax = p.s_tmax
|
||||||
self.s_noise = p.s_noise
|
self.s_noise = p.s_noise
|
||||||
|
self.s_min_uncond = p.s_min_uncond
|
||||||
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
|
||||||
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
|
||||||
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
|
||||||
@ -356,6 +372,9 @@ class Processed:
|
|||||||
def infotext(self, p: StableDiffusionProcessing, index):
|
def infotext(self, p: StableDiffusionProcessing, index):
|
||||||
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
|
||||||
|
|
||||||
|
def get_token_merging_ratio(self, for_hr=False):
|
||||||
|
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
|
||||||
|
|
||||||
|
|
||||||
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
|
||||||
def slerp(val, low, high):
|
def slerp(val, low, high):
|
||||||
@ -454,10 +473,27 @@ def fix_seed(p):
|
|||||||
p.subseed = get_fixed_seed(p.subseed)
|
p.subseed = get_fixed_seed(p.subseed)
|
||||||
|
|
||||||
|
|
||||||
|
def program_version():
|
||||||
|
import launch
|
||||||
|
|
||||||
|
res = launch.git_tag()
|
||||||
|
if res == "<none>":
|
||||||
|
res = None
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
|
||||||
index = position_in_batch + iteration * p.batch_size
|
index = position_in_batch + iteration * p.batch_size
|
||||||
|
|
||||||
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
|
||||||
|
enable_hr = getattr(p, 'enable_hr', False)
|
||||||
|
token_merging_ratio = p.get_token_merging_ratio()
|
||||||
|
token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
|
||||||
|
|
||||||
|
uses_ensd = opts.eta_noise_seed_delta != 0
|
||||||
|
if uses_ensd:
|
||||||
|
uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
|
||||||
|
|
||||||
generation_params = {
|
generation_params = {
|
||||||
"Steps": p.steps,
|
"Steps": p.steps,
|
||||||
@ -475,14 +511,19 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
|
|||||||
"Denoising strength": getattr(p, 'denoising_strength', None),
|
"Denoising strength": getattr(p, 'denoising_strength', None),
|
||||||
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
|
||||||
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
"Clip skip": None if clip_skip <= 1 else clip_skip,
|
||||||
"ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
|
"ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
|
||||||
|
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
|
||||||
|
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
|
||||||
|
"Init image hash": getattr(p, 'init_img_hash', None),
|
||||||
|
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
|
||||||
|
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
|
||||||
|
**p.extra_generation_params,
|
||||||
|
"Version": program_version() if opts.add_version_to_infotext else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
generation_params.update(p.extra_generation_params)
|
|
||||||
|
|
||||||
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
|
||||||
|
|
||||||
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
|
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
|
||||||
|
|
||||||
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
|
||||||
|
|
||||||
@ -491,6 +532,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
|
||||||
|
if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None:
|
||||||
|
p.override_settings.pop('sd_model_checkpoint', None)
|
||||||
|
sd_models.reload_model_weights()
|
||||||
|
|
||||||
for k, v in p.override_settings.items():
|
for k, v in p.override_settings.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
|
|
||||||
@ -500,15 +546,17 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
sd_models.apply_token_merging(p.sd_model, 0)
|
||||||
|
|
||||||
# restore opts to original state
|
# restore opts to original state
|
||||||
if p.override_settings_restore_afterwards:
|
if p.override_settings_restore_afterwards:
|
||||||
for k, v in stored_opts.items():
|
for k, v in stored_opts.items():
|
||||||
setattr(opts, k, v)
|
setattr(opts, k, v)
|
||||||
if k == 'sd_model_checkpoint':
|
|
||||||
sd_models.reload_model_weights()
|
|
||||||
|
|
||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
@ -639,8 +687,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
processed = Processed(p, [], p.seed, "")
|
processed = Processed(p, [], p.seed, "")
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
|
sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
|
||||||
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
|
step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
|
||||||
|
uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
|
||||||
|
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
|
||||||
|
|
||||||
if len(model_hijack.comments) > 0:
|
if len(model_hijack.comments) > 0:
|
||||||
for comment in model_hijack.comments:
|
for comment in model_hijack.comments:
|
||||||
@ -670,6 +720,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
|
||||||
|
|
||||||
for i, x_sample in enumerate(x_samples_ddim):
|
for i, x_sample in enumerate(x_samples_ddim):
|
||||||
|
p.batch_index = i
|
||||||
|
|
||||||
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
|
|
||||||
@ -706,9 +758,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
image.info["parameters"] = text
|
image.info["parameters"] = text
|
||||||
output_images.append(image)
|
output_images.append(image)
|
||||||
|
|
||||||
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:
|
if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
|
||||||
image_mask = p.mask_for_overlay.convert('RGB')
|
image_mask = p.mask_for_overlay.convert('RGB')
|
||||||
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), p.mask_for_overlay.convert('L')).convert('RGBA')
|
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
|
||||||
|
|
||||||
if opts.save_mask:
|
if opts.save_mask:
|
||||||
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
|
||||||
@ -718,7 +770,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
if opts.return_mask:
|
if opts.return_mask:
|
||||||
output_images.append(image_mask)
|
output_images.append(image_mask)
|
||||||
|
|
||||||
if opts.return_mask_composite:
|
if opts.return_mask_composite:
|
||||||
output_images.append(image_mask_composite)
|
output_images.append(image_mask_composite)
|
||||||
|
|
||||||
@ -751,7 +803,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
|
res = Processed(
|
||||||
|
p,
|
||||||
|
images_list=output_images,
|
||||||
|
seed=p.all_seeds[0],
|
||||||
|
info=infotext(),
|
||||||
|
comments="".join(f"\n\n{comment}" for comment in comments),
|
||||||
|
subseed=p.all_subseeds[0],
|
||||||
|
index_of_first_image=index_of_first_image,
|
||||||
|
infotexts=infotexts,
|
||||||
|
)
|
||||||
|
|
||||||
if p.scripts is not None:
|
if p.scripts is not None:
|
||||||
p.scripts.postprocess(p, res)
|
p.scripts.postprocess(p, res)
|
||||||
@ -871,6 +932,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
if not self.enable_hr:
|
if not self.enable_hr:
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
self.is_hr_pass = True
|
||||||
|
|
||||||
target_width = self.hr_upscale_to_x
|
target_width = self.hr_upscale_to_x
|
||||||
target_height = self.hr_upscale_to_y
|
target_height = self.hr_upscale_to_y
|
||||||
|
|
||||||
@ -938,8 +1001,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
x = None
|
x = None
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
|
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
|
||||||
|
|
||||||
|
self.is_hr_pass = False
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
|
||||||
@ -1007,6 +1076,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
self.color_corrections = []
|
self.color_corrections = []
|
||||||
imgs = []
|
imgs = []
|
||||||
for img in self.init_images:
|
for img in self.init_images:
|
||||||
|
|
||||||
|
# Save init image
|
||||||
|
if opts.save_init_img:
|
||||||
|
self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
|
||||||
|
images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
|
||||||
|
|
||||||
image = images.flatten(img, opts.img2img_background_color)
|
image = images.flatten(img, opts.img2img_background_color)
|
||||||
|
|
||||||
if crop_region is None and self.resize_mode != 3:
|
if crop_region is None and self.resize_mode != 3:
|
||||||
@ -1093,3 +1168,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
|
def get_token_merging_ratio(self, for_hr=False):
|
||||||
|
return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio
|
||||||
|
@ -13,6 +13,8 @@ import modules.shared as shared
|
|||||||
current_task = None
|
current_task = None
|
||||||
pending_tasks = {}
|
pending_tasks = {}
|
||||||
finished_tasks = []
|
finished_tasks = []
|
||||||
|
recorded_results = []
|
||||||
|
recorded_results_limit = 2
|
||||||
|
|
||||||
|
|
||||||
def start_task(id_task):
|
def start_task(id_task):
|
||||||
@ -33,6 +35,12 @@ def finish_task(id_task):
|
|||||||
finished_tasks.pop(0)
|
finished_tasks.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def record_results(id_task, res):
|
||||||
|
recorded_results.append((id_task, res))
|
||||||
|
if len(recorded_results) > recorded_results_limit:
|
||||||
|
recorded_results.pop(0)
|
||||||
|
|
||||||
|
|
||||||
def add_task_to_queue(id_job):
|
def add_task_to_queue(id_job):
|
||||||
pending_tasks[id_job] = time.time()
|
pending_tasks[id_job] = time.time()
|
||||||
|
|
||||||
@ -87,8 +95,20 @@ def progressapi(req: ProgressRequest):
|
|||||||
image = shared.state.current_image
|
image = shared.state.current_image
|
||||||
if image is not None:
|
if image is not None:
|
||||||
buffered = io.BytesIO()
|
buffered = io.BytesIO()
|
||||||
image.save(buffered, format="png")
|
|
||||||
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
|
if opts.live_previews_image_format == "png":
|
||||||
|
# using optimize for large images takes an enormous amount of time
|
||||||
|
if max(*image.size) <= 256:
|
||||||
|
save_kwargs = {"optimize": True}
|
||||||
|
else:
|
||||||
|
save_kwargs = {"optimize": False, "compress_level": 1}
|
||||||
|
|
||||||
|
else:
|
||||||
|
save_kwargs = {}
|
||||||
|
|
||||||
|
image.save(buffered, format=opts.live_previews_image_format, **save_kwargs)
|
||||||
|
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
|
||||||
|
live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}"
|
||||||
id_live_preview = shared.state.id_live_preview
|
id_live_preview = shared.state.id_live_preview
|
||||||
else:
|
else:
|
||||||
live_preview = None
|
live_preview = None
|
||||||
@ -97,3 +117,13 @@ def progressapi(req: ProgressRequest):
|
|||||||
|
|
||||||
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
|
||||||
|
|
||||||
|
|
||||||
|
def restore_progress(id_task):
|
||||||
|
while id_task == current_task or id_task in pending_tasks:
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
res = next(iter([x[1] for x in recorded_results if id_task == x[0]]), None)
|
||||||
|
if res is not None:
|
||||||
|
return res
|
||||||
|
|
||||||
|
return gr.update(), gr.update(), gr.update(), f"Couldn't restore progress for {id_task}: results either have been discarded or never were obtained"
|
||||||
|
@ -54,18 +54,21 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def collect_steps(steps, tree):
|
def collect_steps(steps, tree):
|
||||||
l = [steps]
|
res = [steps]
|
||||||
|
|
||||||
class CollectSteps(lark.Visitor):
|
class CollectSteps(lark.Visitor):
|
||||||
def scheduled(self, tree):
|
def scheduled(self, tree):
|
||||||
tree.children[-1] = float(tree.children[-1])
|
tree.children[-1] = float(tree.children[-1])
|
||||||
if tree.children[-1] < 1:
|
if tree.children[-1] < 1:
|
||||||
tree.children[-1] *= steps
|
tree.children[-1] *= steps
|
||||||
tree.children[-1] = min(steps, int(tree.children[-1]))
|
tree.children[-1] = min(steps, int(tree.children[-1]))
|
||||||
l.append(tree.children[-1])
|
res.append(tree.children[-1])
|
||||||
|
|
||||||
def alternate(self, tree):
|
def alternate(self, tree):
|
||||||
l.extend(range(1, steps+1))
|
res.extend(range(1, steps+1))
|
||||||
|
|
||||||
CollectSteps().visit(tree)
|
CollectSteps().visit(tree)
|
||||||
return sorted(set(l))
|
return sorted(set(res))
|
||||||
|
|
||||||
def at_step(step, tree):
|
def at_step(step, tree):
|
||||||
class AtStep(lark.Transformer):
|
class AtStep(lark.Transformer):
|
||||||
@ -92,7 +95,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
|
|||||||
def get_schedule(prompt):
|
def get_schedule(prompt):
|
||||||
try:
|
try:
|
||||||
tree = schedule_parser.parse(prompt)
|
tree = schedule_parser.parse(prompt)
|
||||||
except lark.exceptions.LarkError as e:
|
except lark.exceptions.LarkError:
|
||||||
if 0:
|
if 0:
|
||||||
import traceback
|
import traceback
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -140,7 +143,7 @@ def get_learned_conditioning(model, prompts, steps):
|
|||||||
conds = model.get_learned_conditioning(texts)
|
conds = model.get_learned_conditioning(texts)
|
||||||
|
|
||||||
cond_schedule = []
|
cond_schedule = []
|
||||||
for i, (end_at_step, text) in enumerate(prompt_schedule):
|
for i, (end_at_step, _) in enumerate(prompt_schedule):
|
||||||
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
cond_schedule.append(ScheduledPromptConditioning(end_at_step, conds[i]))
|
||||||
|
|
||||||
cache[prompt] = cond_schedule
|
cache[prompt] = cond_schedule
|
||||||
@ -216,8 +219,8 @@ def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_s
|
|||||||
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
|
||||||
for i, cond_schedule in enumerate(c):
|
for i, cond_schedule in enumerate(c):
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(cond_schedule):
|
for current, entry in enumerate(cond_schedule):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
res[i] = cond_schedule[target_index].cond
|
res[i] = cond_schedule[target_index].cond
|
||||||
@ -231,13 +234,13 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
|
|||||||
tensors = []
|
tensors = []
|
||||||
conds_list = []
|
conds_list = []
|
||||||
|
|
||||||
for batch_no, composable_prompts in enumerate(c.batch):
|
for composable_prompts in c.batch:
|
||||||
conds_for_batch = []
|
conds_for_batch = []
|
||||||
|
|
||||||
for cond_index, composable_prompt in enumerate(composable_prompts):
|
for composable_prompt in composable_prompts:
|
||||||
target_index = 0
|
target_index = 0
|
||||||
for current, (end_at, cond) in enumerate(composable_prompt.schedules):
|
for current, entry in enumerate(composable_prompt.schedules):
|
||||||
if current_step <= end_at:
|
if current_step <= entry.end_at_step:
|
||||||
target_index = current
|
target_index = current
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from realesrgan import RealESRGANer
|
|||||||
|
|
||||||
from modules.upscaler import Upscaler, UpscalerData
|
from modules.upscaler import Upscaler, UpscalerData
|
||||||
from modules.shared import cmd_opts, opts
|
from modules.shared import cmd_opts, opts
|
||||||
|
from modules import modelloader
|
||||||
|
|
||||||
class UpscalerRealESRGAN(Upscaler):
|
class UpscalerRealESRGAN(Upscaler):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
@ -17,13 +17,21 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
self.user_path = path
|
self.user_path = path
|
||||||
super().__init__()
|
super().__init__()
|
||||||
try:
|
try:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer # noqa: F401
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401
|
||||||
self.enable = True
|
self.enable = True
|
||||||
self.scalers = []
|
self.scalers = []
|
||||||
scalers = self.load_models(path)
|
scalers = self.load_models(path)
|
||||||
|
|
||||||
|
local_model_paths = self.find_models(ext_filter=[".pth"])
|
||||||
for scaler in scalers:
|
for scaler in scalers:
|
||||||
|
if scaler.local_data_path.startswith("http"):
|
||||||
|
filename = modelloader.friendly_name(scaler.local_data_path)
|
||||||
|
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
|
||||||
|
if local_model_candidates:
|
||||||
|
scaler.local_data_path = local_model_candidates[0]
|
||||||
|
|
||||||
if scaler.name in opts.realesrgan_enabled_models:
|
if scaler.name in opts.realesrgan_enabled_models:
|
||||||
self.scalers.append(scaler)
|
self.scalers.append(scaler)
|
||||||
|
|
||||||
@ -39,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
|
|
||||||
info = self.load_model(path)
|
info = self.load_model(path)
|
||||||
if not os.path.exists(info.local_data_path):
|
if not os.path.exists(info.local_data_path):
|
||||||
print("Unable to load RealESRGAN model: %s" % info.name)
|
print(f"Unable to load RealESRGAN model: {info.name}")
|
||||||
return img
|
return img
|
||||||
|
|
||||||
upsampler = RealESRGANer(
|
upsampler = RealESRGANer(
|
||||||
@ -64,7 +72,9 @@ class UpscalerRealESRGAN(Upscaler):
|
|||||||
print(f"Unable to find model info: {path}")
|
print(f"Unable to find model info: {path}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
if info.local_data_path.startswith("http"):
|
||||||
|
info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
|
||||||
|
|
||||||
return info
|
return info
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
|
||||||
@ -124,6 +134,6 @@ def get_realesrgan_models(scaler):
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
return models
|
return models
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
print("Error making Real-ESRGAN models list:", file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
# this code is adapted from the script contributed by anon from /h/
|
# this code is adapted from the script contributed by anon from /h/
|
||||||
|
|
||||||
import io
|
|
||||||
import pickle
|
import pickle
|
||||||
import collections
|
import collections
|
||||||
import sys
|
import sys
|
||||||
@ -12,11 +11,9 @@ import _codecs
|
|||||||
import zipfile
|
import zipfile
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
# PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage
|
||||||
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage
|
||||||
|
|
||||||
|
|
||||||
def encode(*args):
|
def encode(*args):
|
||||||
out = _codecs.encode(*args)
|
out = _codecs.encode(*args)
|
||||||
return out
|
return out
|
||||||
@ -27,7 +24,11 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||||||
|
|
||||||
def persistent_load(self, saved_id):
|
def persistent_load(self, saved_id):
|
||||||
assert saved_id[0] == 'storage'
|
assert saved_id[0] == 'storage'
|
||||||
return TypedStorage()
|
|
||||||
|
try:
|
||||||
|
return TypedStorage(_internal=True)
|
||||||
|
except TypeError:
|
||||||
|
return TypedStorage() # PyTorch before 2.0 does not have the _internal argument
|
||||||
|
|
||||||
def find_class(self, module, name):
|
def find_class(self, module, name):
|
||||||
if self.extra_handler is not None:
|
if self.extra_handler is not None:
|
||||||
@ -39,7 +40,7 @@ class RestrictedUnpickler(pickle.Unpickler):
|
|||||||
return getattr(collections, name)
|
return getattr(collections, name)
|
||||||
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
|
||||||
return getattr(torch._utils, name)
|
return getattr(torch._utils, name)
|
||||||
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
|
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32', 'BFloat16Storage']:
|
||||||
return getattr(torch, name)
|
return getattr(torch, name)
|
||||||
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
|
||||||
return getattr(torch.nn.modules.container, name)
|
return getattr(torch.nn.modules.container, name)
|
||||||
@ -94,16 +95,16 @@ def check_pt(filename, extra_handler):
|
|||||||
|
|
||||||
except zipfile.BadZipfile:
|
except zipfile.BadZipfile:
|
||||||
|
|
||||||
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
|
# if it's not a zip file, it's an old pytorch format, with five objects written to pickle
|
||||||
with open(filename, "rb") as file:
|
with open(filename, "rb") as file:
|
||||||
unpickler = RestrictedUnpickler(file)
|
unpickler = RestrictedUnpickler(file)
|
||||||
unpickler.extra_handler = extra_handler
|
unpickler.extra_handler = extra_handler
|
||||||
for i in range(5):
|
for _ in range(5):
|
||||||
unpickler.load()
|
unpickler.load()
|
||||||
|
|
||||||
|
|
||||||
def load(filename, *args, **kwargs):
|
def load(filename, *args, **kwargs):
|
||||||
return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs)
|
return load_with_extra(filename, *args, extra_handler=global_extra_handler, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
|
||||||
|
@ -32,27 +32,42 @@ class CFGDenoiserParams:
|
|||||||
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
self.image_cond = image_cond
|
self.image_cond = image_cond
|
||||||
"""Conditioning image"""
|
"""Conditioning image"""
|
||||||
|
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
"""Current sigma noise step value"""
|
"""Current sigma noise step value"""
|
||||||
|
|
||||||
self.sampling_step = sampling_step
|
self.sampling_step = sampling_step
|
||||||
"""Current Sampling step number"""
|
"""Current Sampling step number"""
|
||||||
|
|
||||||
self.total_sampling_steps = total_sampling_steps
|
self.total_sampling_steps = total_sampling_steps
|
||||||
"""Total number of sampling steps planned"""
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
self.text_cond = text_cond
|
self.text_cond = text_cond
|
||||||
""" Encoder hidden states of text conditioning from prompt"""
|
""" Encoder hidden states of text conditioning from prompt"""
|
||||||
|
|
||||||
self.text_uncond = text_uncond
|
self.text_uncond = text_uncond
|
||||||
""" Encoder hidden states of text conditioning from negative prompt"""
|
""" Encoder hidden states of text conditioning from negative prompt"""
|
||||||
|
|
||||||
|
|
||||||
class CFGDenoisedParams:
|
class CFGDenoisedParams:
|
||||||
|
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
|
||||||
|
self.x = x
|
||||||
|
"""Latent image representation in the process of being denoised"""
|
||||||
|
|
||||||
|
self.sampling_step = sampling_step
|
||||||
|
"""Current Sampling step number"""
|
||||||
|
|
||||||
|
self.total_sampling_steps = total_sampling_steps
|
||||||
|
"""Total number of sampling steps planned"""
|
||||||
|
|
||||||
|
self.inner_model = inner_model
|
||||||
|
"""Inner model reference used for denoising"""
|
||||||
|
|
||||||
|
|
||||||
|
class AfterCFGCallbackParams:
|
||||||
def __init__(self, x, sampling_step, total_sampling_steps):
|
def __init__(self, x, sampling_step, total_sampling_steps):
|
||||||
self.x = x
|
self.x = x
|
||||||
"""Latent image representation in the process of being denoised"""
|
"""Latent image representation in the process of being denoised"""
|
||||||
@ -87,12 +102,14 @@ callback_map = dict(
|
|||||||
callbacks_image_saved=[],
|
callbacks_image_saved=[],
|
||||||
callbacks_cfg_denoiser=[],
|
callbacks_cfg_denoiser=[],
|
||||||
callbacks_cfg_denoised=[],
|
callbacks_cfg_denoised=[],
|
||||||
|
callbacks_cfg_after_cfg=[],
|
||||||
callbacks_before_component=[],
|
callbacks_before_component=[],
|
||||||
callbacks_after_component=[],
|
callbacks_after_component=[],
|
||||||
callbacks_image_grid=[],
|
callbacks_image_grid=[],
|
||||||
callbacks_infotext_pasted=[],
|
callbacks_infotext_pasted=[],
|
||||||
callbacks_script_unloaded=[],
|
callbacks_script_unloaded=[],
|
||||||
callbacks_before_ui=[],
|
callbacks_before_ui=[],
|
||||||
|
callbacks_on_reload=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -109,6 +126,14 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI):
|
|||||||
report_exception(c, 'app_started_callback')
|
report_exception(c, 'app_started_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def app_reload_callback():
|
||||||
|
for c in callback_map['callbacks_on_reload']:
|
||||||
|
try:
|
||||||
|
c.callback()
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'callbacks_on_reload')
|
||||||
|
|
||||||
|
|
||||||
def model_loaded_callback(sd_model):
|
def model_loaded_callback(sd_model):
|
||||||
for c in callback_map['callbacks_model_loaded']:
|
for c in callback_map['callbacks_model_loaded']:
|
||||||
try:
|
try:
|
||||||
@ -177,6 +202,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
|
|||||||
report_exception(c, 'cfg_denoised_callback')
|
report_exception(c, 'cfg_denoised_callback')
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
|
||||||
|
for c in callback_map['callbacks_cfg_after_cfg']:
|
||||||
|
try:
|
||||||
|
c.callback(params)
|
||||||
|
except Exception:
|
||||||
|
report_exception(c, 'cfg_after_cfg_callback')
|
||||||
|
|
||||||
|
|
||||||
def before_component_callback(component, **kwargs):
|
def before_component_callback(component, **kwargs):
|
||||||
for c in callback_map['callbacks_before_component']:
|
for c in callback_map['callbacks_before_component']:
|
||||||
try:
|
try:
|
||||||
@ -231,7 +264,7 @@ def add_callback(callbacks, fun):
|
|||||||
|
|
||||||
callbacks.append(ScriptCallback(filename, fun))
|
callbacks.append(ScriptCallback(filename, fun))
|
||||||
|
|
||||||
|
|
||||||
def remove_current_script_callbacks():
|
def remove_current_script_callbacks():
|
||||||
stack = [x for x in inspect.stack() if x.filename != __file__]
|
stack = [x for x in inspect.stack() if x.filename != __file__]
|
||||||
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
|
||||||
@ -254,6 +287,11 @@ def on_app_started(callback):
|
|||||||
add_callback(callback_map['callbacks_app_started'], callback)
|
add_callback(callback_map['callbacks_app_started'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_before_reload(callback):
|
||||||
|
"""register a function to be called just before the server reloads."""
|
||||||
|
add_callback(callback_map['callbacks_on_reload'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_model_loaded(callback):
|
def on_model_loaded(callback):
|
||||||
"""register a function to be called when the stable diffusion model is created; the model is
|
"""register a function to be called when the stable diffusion model is created; the model is
|
||||||
passed as an argument; this function is also called when the script is reloaded. """
|
passed as an argument; this function is also called when the script is reloaded. """
|
||||||
@ -318,6 +356,14 @@ def on_cfg_denoised(callback):
|
|||||||
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
add_callback(callback_map['callbacks_cfg_denoised'], callback)
|
||||||
|
|
||||||
|
|
||||||
|
def on_cfg_after_cfg(callback):
|
||||||
|
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
|
||||||
|
The callback is called with one argument:
|
||||||
|
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
|
||||||
|
"""
|
||||||
|
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
|
||||||
|
|
||||||
|
|
||||||
def on_before_component(callback):
|
def on_before_component(callback):
|
||||||
"""register a function to be called before a component is created.
|
"""register a function to be called before a component is created.
|
||||||
The callback is called with arguments:
|
The callback is called with arguments:
|
||||||
|
@ -2,7 +2,6 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from types import ModuleType
|
|
||||||
|
|
||||||
|
|
||||||
def load_module(path):
|
def load_module(path):
|
||||||
|
@ -17,6 +17,9 @@ class PostprocessImageArgs:
|
|||||||
|
|
||||||
|
|
||||||
class Script:
|
class Script:
|
||||||
|
name = None
|
||||||
|
"""script's internal name derived from title"""
|
||||||
|
|
||||||
filename = None
|
filename = None
|
||||||
args_from = None
|
args_from = None
|
||||||
args_to = None
|
args_to = None
|
||||||
@ -25,8 +28,8 @@ class Script:
|
|||||||
is_txt2img = False
|
is_txt2img = False
|
||||||
is_img2img = False
|
is_img2img = False
|
||||||
|
|
||||||
"""A gr.Group component that has all script's UI inside it"""
|
|
||||||
group = None
|
group = None
|
||||||
|
"""A gr.Group component that has all script's UI inside it"""
|
||||||
|
|
||||||
infotext_fields = None
|
infotext_fields = None
|
||||||
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
|
||||||
@ -38,6 +41,9 @@ class Script:
|
|||||||
various "Send to <X>" buttons when clicked
|
various "Send to <X>" buttons when clicked
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
api_info = None
|
||||||
|
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
|
||||||
|
|
||||||
def title(self):
|
def title(self):
|
||||||
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
|
||||||
|
|
||||||
@ -163,7 +169,8 @@ class Script:
|
|||||||
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
|
||||||
|
|
||||||
need_tabname = self.show(True) == self.show(False)
|
need_tabname = self.show(True) == self.show(False)
|
||||||
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
|
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
|
||||||
|
tabname = f"{tabkind}_" if need_tabname else ""
|
||||||
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
|
||||||
|
|
||||||
return f'script_{tabname}{title}_{item_id}'
|
return f'script_{tabname}{title}_{item_id}'
|
||||||
@ -230,7 +237,7 @@ def load_scripts():
|
|||||||
syspath = sys.path
|
syspath = sys.path
|
||||||
|
|
||||||
def register_scripts_from_module(module):
|
def register_scripts_from_module(module):
|
||||||
for key, script_class in module.__dict__.items():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) != type:
|
if type(script_class) != type:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -294,9 +301,9 @@ class ScriptRunner:
|
|||||||
|
|
||||||
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
|
for script_data in auto_processing_scripts + scripts_data:
|
||||||
script = script_class()
|
script = script_data.script_class()
|
||||||
script.filename = path
|
script.filename = script_data.path
|
||||||
script.is_txt2img = not is_img2img
|
script.is_txt2img = not is_img2img
|
||||||
script.is_img2img = is_img2img
|
script.is_img2img = is_img2img
|
||||||
|
|
||||||
@ -312,6 +319,8 @@ class ScriptRunner:
|
|||||||
self.selectable_scripts.append(script)
|
self.selectable_scripts.append(script)
|
||||||
|
|
||||||
def setup_ui(self):
|
def setup_ui(self):
|
||||||
|
import modules.api.models as api_models
|
||||||
|
|
||||||
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
|
||||||
|
|
||||||
inputs = [None]
|
inputs = [None]
|
||||||
@ -326,9 +335,28 @@ class ScriptRunner:
|
|||||||
if controls is None:
|
if controls is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower()
|
||||||
|
api_args = []
|
||||||
|
|
||||||
for control in controls:
|
for control in controls:
|
||||||
control.custom_script_source = os.path.basename(script.filename)
|
control.custom_script_source = os.path.basename(script.filename)
|
||||||
|
|
||||||
|
arg_info = api_models.ScriptArg(label=control.label or "")
|
||||||
|
|
||||||
|
for field in ("value", "minimum", "maximum", "step", "choices"):
|
||||||
|
v = getattr(control, field, None)
|
||||||
|
if v is not None:
|
||||||
|
setattr(arg_info, field, v)
|
||||||
|
|
||||||
|
api_args.append(arg_info)
|
||||||
|
|
||||||
|
script.api_info = api_models.ScriptInfo(
|
||||||
|
name=script.name,
|
||||||
|
is_img2img=script.is_img2img,
|
||||||
|
is_alwayson=script.alwayson,
|
||||||
|
args=api_args,
|
||||||
|
)
|
||||||
|
|
||||||
if script.infotext_fields is not None:
|
if script.infotext_fields is not None:
|
||||||
self.infotext_fields += script.infotext_fields
|
self.infotext_fields += script.infotext_fields
|
||||||
|
|
||||||
@ -491,7 +519,7 @@ class ScriptRunner:
|
|||||||
module = script_loading.load_module(script.filename)
|
module = script_loading.load_module(script.filename)
|
||||||
cache[filename] = module
|
cache[filename] = module
|
||||||
|
|
||||||
for key, script_class in module.__dict__.items():
|
for script_class in module.__dict__.values():
|
||||||
if type(script_class) == type and issubclass(script_class, Script):
|
if type(script_class) == type and issubclass(script_class, Script):
|
||||||
self.scripts[si] = script_class()
|
self.scripts[si] = script_class()
|
||||||
self.scripts[si].filename = filename
|
self.scripts[si].filename = filename
|
||||||
@ -526,7 +554,7 @@ def add_classes_to_gradio_component(comp):
|
|||||||
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
|
||||||
"""
|
"""
|
||||||
|
|
||||||
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
|
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
|
||||||
|
|
||||||
if getattr(comp, 'multiselect', False):
|
if getattr(comp, 'multiselect', False):
|
||||||
comp.elem_classes.append('multiselect')
|
comp.elem_classes.append('multiselect')
|
||||||
|
@ -17,7 +17,7 @@ class ScriptPostprocessingForMainUI(scripts.Script):
|
|||||||
return self.postprocessing_controls.values()
|
return self.postprocessing_controls.values()
|
||||||
|
|
||||||
def postprocess_image(self, p, script_pp, *args):
|
def postprocess_image(self, p, script_pp, *args):
|
||||||
args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)}
|
args_dict = dict(zip(self.postprocessing_controls, args))
|
||||||
|
|
||||||
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
pp = scripts_postprocessing.PostprocessedImage(script_pp.image)
|
||||||
pp.info = {}
|
pp.info = {}
|
||||||
|
@ -66,9 +66,9 @@ class ScriptPostprocessingRunner:
|
|||||||
def initialize_scripts(self, scripts_data):
|
def initialize_scripts(self, scripts_data):
|
||||||
self.scripts = []
|
self.scripts = []
|
||||||
|
|
||||||
for script_class, path, basedir, script_module in scripts_data:
|
for script_data in scripts_data:
|
||||||
script: ScriptPostprocessing = script_class()
|
script: ScriptPostprocessing = script_data.script_class()
|
||||||
script.filename = path
|
script.filename = script_data.path
|
||||||
|
|
||||||
if script.name == "Simple Upscale":
|
if script.name == "Simple Upscale":
|
||||||
continue
|
continue
|
||||||
@ -124,7 +124,7 @@ class ScriptPostprocessingRunner:
|
|||||||
script_args = args[script.args_from:script.args_to]
|
script_args = args[script.args_from:script.args_to]
|
||||||
|
|
||||||
process_args = {}
|
process_args = {}
|
||||||
for (name, component), value in zip(script.controls.items(), script_args):
|
for (name, _component), value in zip(script.controls.items(), script_args):
|
||||||
process_args[name] = value
|
process_args[name] = value
|
||||||
|
|
||||||
script.process(pp, **process_args)
|
script.process(pp, **process_args)
|
||||||
|
@ -61,7 +61,7 @@ class DisableInitialization:
|
|||||||
if res is None:
|
if res is None:
|
||||||
res = original(url, *args, local_files_only=False, **kwargs)
|
res = original(url, *args, local_files_only=False, **kwargs)
|
||||||
return res
|
return res
|
||||||
except Exception as e:
|
except Exception:
|
||||||
return original(url, *args, local_files_only=False, **kwargs)
|
return original(url, *args, local_files_only=False, **kwargs)
|
||||||
|
|
||||||
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
|
||||||
|
@ -3,7 +3,7 @@ from torch.nn.functional import silu
|
|||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
import modules.textual_inversion.textual_inversion
|
import modules.textual_inversion.textual_inversion
|
||||||
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
|
from modules import devices, sd_hijack_optimizations, shared
|
||||||
from modules.hypernetworks import hypernetwork
|
from modules.hypernetworks import hypernetwork
|
||||||
from modules.shared import cmd_opts
|
from modules.shared import cmd_opts
|
||||||
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
|
||||||
@ -34,10 +34,10 @@ def apply_optimizations():
|
|||||||
|
|
||||||
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
ldm.modules.diffusionmodules.model.nonlinearity = silu
|
||||||
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
|
||||||
|
|
||||||
optimization_method = None
|
optimization_method = None
|
||||||
|
|
||||||
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
|
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
|
||||||
|
|
||||||
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
|
||||||
print("Applying xformers cross attention optimization.")
|
print("Applying xformers cross attention optimization.")
|
||||||
@ -92,12 +92,12 @@ def fix_checkpoint():
|
|||||||
def weighted_loss(sd_model, pred, target, mean=True):
|
def weighted_loss(sd_model, pred, target, mean=True):
|
||||||
#Calculate the weight normally, but ignore the mean
|
#Calculate the weight normally, but ignore the mean
|
||||||
loss = sd_model._old_get_loss(pred, target, mean=False)
|
loss = sd_model._old_get_loss(pred, target, mean=False)
|
||||||
|
|
||||||
#Check if we have weights available
|
#Check if we have weights available
|
||||||
weight = getattr(sd_model, '_custom_loss_weight', None)
|
weight = getattr(sd_model, '_custom_loss_weight', None)
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
loss *= weight
|
loss *= weight
|
||||||
|
|
||||||
#Return the loss, as mean if specified
|
#Return the loss, as mean if specified
|
||||||
return loss.mean() if mean else loss
|
return loss.mean() if mean else loss
|
||||||
|
|
||||||
@ -105,7 +105,7 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
#Temporarily append weights to a place accessible during loss calc
|
#Temporarily append weights to a place accessible during loss calc
|
||||||
sd_model._custom_loss_weight = w
|
sd_model._custom_loss_weight = w
|
||||||
|
|
||||||
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
#Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely
|
||||||
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
#Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set
|
||||||
if not hasattr(sd_model, '_old_get_loss'):
|
if not hasattr(sd_model, '_old_get_loss'):
|
||||||
@ -118,9 +118,9 @@ def weighted_forward(sd_model, x, c, w, *args, **kwargs):
|
|||||||
try:
|
try:
|
||||||
#Delete temporary weights if appended
|
#Delete temporary weights if appended
|
||||||
del sd_model._custom_loss_weight
|
del sd_model._custom_loss_weight
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
#If we have an old loss function, reset the loss function to the original one
|
#If we have an old loss function, reset the loss function to the original one
|
||||||
if hasattr(sd_model, '_old_get_loss'):
|
if hasattr(sd_model, '_old_get_loss'):
|
||||||
sd_model.get_loss = sd_model._old_get_loss
|
sd_model.get_loss = sd_model._old_get_loss
|
||||||
@ -133,7 +133,7 @@ def apply_weighted_forward(sd_model):
|
|||||||
def undo_weighted_forward(sd_model):
|
def undo_weighted_forward(sd_model):
|
||||||
try:
|
try:
|
||||||
del sd_model.weighted_forward
|
del sd_model.weighted_forward
|
||||||
except AttributeError as e:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -184,7 +184,7 @@ class StableDiffusionModelHijack:
|
|||||||
|
|
||||||
def undo_hijack(self, m):
|
def undo_hijack(self, m):
|
||||||
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
|
|
||||||
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
|
||||||
m.cond_stage_model = m.cond_stage_model.wrapped
|
m.cond_stage_model = m.cond_stage_model.wrapped
|
||||||
@ -216,6 +216,9 @@ class StableDiffusionModelHijack:
|
|||||||
self.comments = []
|
self.comments = []
|
||||||
|
|
||||||
def get_prompt_lengths(self, text):
|
def get_prompt_lengths(self, text):
|
||||||
|
if self.clip is None:
|
||||||
|
return "-", "-"
|
||||||
|
|
||||||
_, token_count = self.clip.process_texts([text])
|
_, token_count = self.clip.process_texts([text])
|
||||||
|
|
||||||
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
return token_count, self.clip.get_target_prompt_token_count(token_count)
|
||||||
|
@ -223,7 +223,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
|
|||||||
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
self.hijack.fixes = [x.fixes for x in batch_chunk]
|
||||||
|
|
||||||
for fixes in self.hijack.fixes:
|
for fixes in self.hijack.fixes:
|
||||||
for position, embedding in fixes:
|
for _position, embedding in fixes:
|
||||||
used_embeddings[embedding.name] = embedding
|
used_embeddings[embedding.name] = embedding
|
||||||
|
|
||||||
z = self.process_tokens(tokens, multipliers)
|
z = self.process_tokens(tokens, multipliers)
|
||||||
|
@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
|
|||||||
self.hijack.comments += hijack_comments
|
self.hijack.comments += hijack_comments
|
||||||
|
|
||||||
if len(used_custom_terms) > 0:
|
if len(used_custom_terms) > 0:
|
||||||
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
|
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
|
||||||
|
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
|
||||||
|
|
||||||
self.hijack.fixes = hijack_fixes
|
self.hijack.fixes = hijack_fixes
|
||||||
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
return self.process_tokens(remade_batch_tokens, batch_multipliers)
|
||||||
|
@ -1,16 +1,10 @@
|
|||||||
import os
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from einops import repeat
|
|
||||||
from omegaconf import ListConfig
|
|
||||||
|
|
||||||
import ldm.models.diffusion.ddpm
|
import ldm.models.diffusion.ddpm
|
||||||
import ldm.models.diffusion.ddim
|
import ldm.models.diffusion.ddim
|
||||||
import ldm.models.diffusion.plms
|
import ldm.models.diffusion.plms
|
||||||
|
|
||||||
from ldm.models.diffusion.ddpm import LatentDiffusion
|
from ldm.models.diffusion.ddim import noise_like
|
||||||
from ldm.models.diffusion.plms import PLMSSampler
|
|
||||||
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
|
|
||||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||||
|
|
||||||
|
|
||||||
@ -29,7 +23,7 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
|
|||||||
|
|
||||||
if isinstance(c, dict):
|
if isinstance(c, dict):
|
||||||
assert isinstance(unconditional_conditioning, dict)
|
assert isinstance(unconditional_conditioning, dict)
|
||||||
c_in = dict()
|
c_in = {}
|
||||||
for k in c:
|
for k in c:
|
||||||
if isinstance(c[k], list):
|
if isinstance(c[k], list):
|
||||||
c_in[k] = [
|
c_in[k] = [
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
import collections
|
|
||||||
import os.path
|
import os.path
|
||||||
import sys
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
|
|
||||||
def should_hijack_ip2p(checkpoint_info):
|
def should_hijack_ip2p(checkpoint_info):
|
||||||
from modules import sd_models_config
|
from modules import sd_models_config
|
||||||
@ -10,4 +7,4 @@ def should_hijack_ip2p(checkpoint_info):
|
|||||||
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
|
||||||
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower()
|
||||||
|
|
||||||
return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename
|
return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename
|
||||||
|
@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
del context, context_k, context_v, x
|
del context, context_k, context_v, x
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -62,10 +62,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
|
|||||||
end = i + 2
|
end = i + 2
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||||
s1 *= self.scale
|
s1 *= self.scale
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1)
|
s2 = s1.softmax(dim=-1)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||||
del s2
|
del s2
|
||||||
del q, k, v
|
del q, k, v
|
||||||
@ -95,43 +95,43 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||||
k_in = k_in * self.scale
|
k_in = k_in * self.scale
|
||||||
|
|
||||||
del context, x
|
del context, x
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||||
|
|
||||||
mem_free_total = get_available_vram()
|
mem_free_total = get_available_vram()
|
||||||
|
|
||||||
gb = 1024 ** 3
|
gb = 1024 ** 3
|
||||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
|
||||||
modifier = 3 if q.element_size() == 2 else 2.5
|
modifier = 3 if q.element_size() == 2 else 2.5
|
||||||
mem_required = tensor_size * modifier
|
mem_required = tensor_size * modifier
|
||||||
steps = 1
|
steps = 1
|
||||||
|
|
||||||
if mem_required > mem_free_total:
|
if mem_required > mem_free_total:
|
||||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
|
||||||
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
|
||||||
|
|
||||||
if steps > 64:
|
if steps > 64:
|
||||||
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
|
||||||
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
|
||||||
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
|
||||||
|
|
||||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||||
del s1
|
del s1
|
||||||
|
|
||||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||||
del s2
|
del s2
|
||||||
|
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
r1 = r1.to(dtype)
|
r1 = r1.to(dtype)
|
||||||
@ -228,8 +228,8 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
|
|||||||
|
|
||||||
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
with devices.without_autocast(disable=not shared.opts.upcast_attn):
|
||||||
k = k * self.scale
|
k = k * self.scale
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
|
||||||
r = einsum_op(q, k, v)
|
r = einsum_op(q, k, v)
|
||||||
r = r.to(dtype)
|
r = r.to(dtype)
|
||||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||||
@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
|
|||||||
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
|
||||||
|
|
||||||
|
if q.device.type == 'mps':
|
||||||
|
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -293,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
|
|||||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||||
# i.e. send it down the unchunked fast-path
|
# i.e. send it down the unchunked fast-path
|
||||||
query_chunk_size = q_tokens
|
|
||||||
kv_chunk_size = k_tokens
|
kv_chunk_size = k_tokens
|
||||||
|
|
||||||
with devices.without_autocast(disable=q.dtype == v.dtype):
|
with devices.without_autocast(disable=q.dtype == v.dtype):
|
||||||
@ -332,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
|
|||||||
k_in = self.to_k(context_k)
|
k_in = self.to_k(context_k)
|
||||||
v_in = self.to_v(context_v)
|
v_in = self.to_v(context_v)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
|
q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -367,7 +369,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
|
|||||||
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
del q_in, k_in, v_in
|
del q_in, k_in, v_in
|
||||||
|
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
@ -449,7 +451,7 @@ def cross_attention_attnblock_forward(self, x):
|
|||||||
h3 += x
|
h3 += x
|
||||||
|
|
||||||
return h3
|
return h3
|
||||||
|
|
||||||
def xformers_attnblock_forward(self, x):
|
def xformers_attnblock_forward(self, x):
|
||||||
try:
|
try:
|
||||||
h_ = x
|
h_ = x
|
||||||
@ -458,7 +460,7 @@ def xformers_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -480,7 +482,7 @@ def sdp_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
dtype = q.dtype
|
dtype = q.dtype
|
||||||
if shared.opts.upcast_attn:
|
if shared.opts.upcast_attn:
|
||||||
q, k = q.float(), k.float()
|
q, k = q.float(), k.float()
|
||||||
@ -504,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
|
|||||||
k = self.k(h_)
|
k = self.k(h_)
|
||||||
v = self.v(h_)
|
v = self.v(h_)
|
||||||
b, c, h, w = q.shape
|
b, c, h, w = q.shape
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
|
q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
|
||||||
q = q.contiguous()
|
q = q.contiguous()
|
||||||
k = k.contiguous()
|
k = k.contiguous()
|
||||||
v = v.contiguous()
|
v = v.contiguous()
|
||||||
|
@ -18,7 +18,7 @@ class TorchHijackForUnet:
|
|||||||
if hasattr(torch, item):
|
if hasattr(torch, item):
|
||||||
return getattr(torch, item)
|
return getattr(torch, item)
|
||||||
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
|
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
|
||||||
|
|
||||||
def cat(self, tensors, *args, **kwargs):
|
def cat(self, tensors, *args, **kwargs):
|
||||||
if len(tensors) == 2:
|
if len(tensors) == 2:
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import open_clip.tokenizer
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules import sd_hijack_clip, devices
|
from modules import sd_hijack_clip, devices
|
||||||
from modules.shared import opts
|
|
||||||
|
|
||||||
|
|
||||||
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
|
||||||
|
@ -2,6 +2,8 @@ import collections
|
|||||||
import os.path
|
import os.path
|
||||||
import sys
|
import sys
|
||||||
import gc
|
import gc
|
||||||
|
import threading
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
@ -13,9 +15,9 @@ import ldm.modules.midas as midas
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
|
||||||
from modules.paths import models_path
|
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
import tomesd
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -45,20 +47,29 @@ class CheckpointInfo:
|
|||||||
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
|
||||||
self.hash = model_hash(filename)
|
self.hash = model_hash(filename)
|
||||||
|
|
||||||
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
|
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
|
||||||
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
self.shorthash = self.sha256[0:10] if self.sha256 else None
|
||||||
|
|
||||||
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
|
||||||
|
|
||||||
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
|
||||||
|
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
_, ext = os.path.splitext(self.filename)
|
||||||
|
if ext.lower() == ".safetensors":
|
||||||
|
try:
|
||||||
|
self.metadata = read_metadata_from_safetensors(filename)
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, f"reading checkpoint metadata: {filename}")
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
checkpoints_list[self.title] = self
|
checkpoints_list[self.title] = self
|
||||||
for id in self.ids:
|
for id in self.ids:
|
||||||
checkpoint_alisases[id] = self
|
checkpoint_alisases[id] = self
|
||||||
|
|
||||||
def calculate_shorthash(self):
|
def calculate_shorthash(self):
|
||||||
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
|
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
|
||||||
if self.sha256 is None:
|
if self.sha256 is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -76,8 +87,7 @@ class CheckpointInfo:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
|
||||||
|
from transformers import logging, CLIPModel # noqa: F401
|
||||||
from transformers import logging, CLIPModel
|
|
||||||
|
|
||||||
logging.set_verbosity_error()
|
logging.set_verbosity_error()
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -156,7 +166,7 @@ def model_hash(filename):
|
|||||||
|
|
||||||
def select_checkpoint():
|
def select_checkpoint():
|
||||||
model_checkpoint = shared.opts.sd_model_checkpoint
|
model_checkpoint = shared.opts.sd_model_checkpoint
|
||||||
|
|
||||||
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
|
||||||
if checkpoint_info is not None:
|
if checkpoint_info is not None:
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
@ -228,7 +238,7 @@ def read_metadata_from_safetensors(filename):
|
|||||||
if isinstance(v, str) and v[0:1] == '{':
|
if isinstance(v, str) and v[0:1] == '{':
|
||||||
try:
|
try:
|
||||||
res[k] = json.loads(v)
|
res[k] = json.loads(v)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return res
|
return res
|
||||||
@ -363,7 +373,7 @@ def enable_midas_autodownload():
|
|||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
if not os.path.exists(midas_path):
|
if not os.path.exists(midas_path):
|
||||||
mkdir(midas_path)
|
mkdir(midas_path)
|
||||||
|
|
||||||
print(f"Downloading midas model weights for {model_type} to {path}")
|
print(f"Downloading midas model weights for {model_type} to {path}")
|
||||||
request.urlretrieve(midas_urls[model_type], path)
|
request.urlretrieve(midas_urls[model_type], path)
|
||||||
print(f"{model_type} downloaded")
|
print(f"{model_type} downloaded")
|
||||||
@ -395,13 +405,42 @@ def repair_config(sd_config):
|
|||||||
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
|
||||||
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
|
||||||
|
|
||||||
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
|
|
||||||
|
class SdModelData:
|
||||||
|
def __init__(self):
|
||||||
|
self.sd_model = None
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def get_sd_model(self):
|
||||||
|
if self.sd_model is None:
|
||||||
|
with self.lock:
|
||||||
|
if self.sd_model is not None:
|
||||||
|
return self.sd_model
|
||||||
|
|
||||||
|
try:
|
||||||
|
load_model()
|
||||||
|
except Exception as e:
|
||||||
|
errors.display(e, "loading stable diffusion model")
|
||||||
|
print("", file=sys.stderr)
|
||||||
|
print("Stable diffusion model failed to load", file=sys.stderr)
|
||||||
|
self.sd_model = None
|
||||||
|
|
||||||
|
return self.sd_model
|
||||||
|
|
||||||
|
def set_sd_model(self, v):
|
||||||
|
self.sd_model = v
|
||||||
|
|
||||||
|
|
||||||
|
model_data = SdModelData()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
||||||
from modules import lowvram, sd_hijack
|
from modules import lowvram, sd_hijack
|
||||||
checkpoint_info = checkpoint_info or select_checkpoint()
|
checkpoint_info = checkpoint_info or select_checkpoint()
|
||||||
|
|
||||||
if shared.sd_model:
|
if model_data.sd_model:
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||||
shared.sd_model = None
|
model_data.sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
@ -430,7 +469,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||||||
try:
|
try:
|
||||||
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
|
||||||
sd_model = instantiate_from_config(sd_config.model)
|
sd_model = instantiate_from_config(sd_config.model)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if sd_model is None:
|
if sd_model is None:
|
||||||
@ -455,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
|
|||||||
timer.record("hijack")
|
timer.record("hijack")
|
||||||
|
|
||||||
sd_model.eval()
|
sd_model.eval()
|
||||||
shared.sd_model = sd_model
|
model_data.sd_model = sd_model
|
||||||
|
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
|
||||||
|
|
||||||
@ -475,7 +514,7 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
checkpoint_info = info or select_checkpoint()
|
checkpoint_info = info or select_checkpoint()
|
||||||
|
|
||||||
if not sd_model:
|
if not sd_model:
|
||||||
sd_model = shared.sd_model
|
sd_model = model_data.sd_model
|
||||||
|
|
||||||
if sd_model is None: # previous model load failed
|
if sd_model is None: # previous model load failed
|
||||||
current_checkpoint_info = None
|
current_checkpoint_info = None
|
||||||
@ -501,13 +540,12 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
if sd_model is None or checkpoint_config != sd_model.used_config:
|
if sd_model is None or checkpoint_config != sd_model.used_config:
|
||||||
del sd_model
|
del sd_model
|
||||||
checkpoints_loaded.clear()
|
|
||||||
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
|
||||||
return shared.sd_model
|
return model_data.sd_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
print("Failed to load checkpoint, restoring previous")
|
print("Failed to load checkpoint, restoring previous")
|
||||||
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
load_model_weights(sd_model, current_checkpoint_info, None, timer)
|
||||||
raise
|
raise
|
||||||
@ -526,17 +564,15 @@ def reload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
def unload_model_weights(sd_model=None, info=None):
|
def unload_model_weights(sd_model=None, info=None):
|
||||||
from modules import lowvram, devices, sd_hijack
|
from modules import devices, sd_hijack
|
||||||
timer = Timer()
|
timer = Timer()
|
||||||
|
|
||||||
if shared.sd_model:
|
if model_data.sd_model:
|
||||||
|
model_data.sd_model.to(devices.cpu)
|
||||||
# shared.sd_model.cond_stage_model.to(devices.cpu)
|
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
|
||||||
# shared.sd_model.first_stage_model.to(devices.cpu)
|
model_data.sd_model = None
|
||||||
shared.sd_model.to(devices.cpu)
|
|
||||||
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
|
|
||||||
shared.sd_model = None
|
|
||||||
sd_model = None
|
sd_model = None
|
||||||
gc.collect()
|
gc.collect()
|
||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
@ -544,4 +580,30 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_token_merging(sd_model, token_merging_ratio):
|
||||||
|
"""
|
||||||
|
Applies speed and memory optimizations from tomesd.
|
||||||
|
"""
|
||||||
|
|
||||||
|
current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)
|
||||||
|
|
||||||
|
if current_token_merging_ratio == token_merging_ratio:
|
||||||
|
return
|
||||||
|
|
||||||
|
if current_token_merging_ratio > 0:
|
||||||
|
tomesd.remove_patch(sd_model)
|
||||||
|
|
||||||
|
if token_merging_ratio > 0:
|
||||||
|
tomesd.apply_patch(
|
||||||
|
sd_model,
|
||||||
|
ratio=token_merging_ratio,
|
||||||
|
use_rand=False, # can cause issues with some samplers
|
||||||
|
merge_attn=True,
|
||||||
|
merge_crossattn=False,
|
||||||
|
merge_mlp=False
|
||||||
|
)
|
||||||
|
|
||||||
|
sd_model.applied_token_merged_ratio = token_merging_ratio
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import re
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -111,7 +110,7 @@ def find_checkpoint_config_near_filename(info):
|
|||||||
if info is None:
|
if info is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
config = os.path.splitext(info.filename)[0] + ".yaml"
|
config = f"{os.path.splitext(info.filename)[0]}.yaml"
|
||||||
if os.path.exists(config):
|
if os.path.exists(config):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
|
||||||
|
|
||||||
# imports for functions that previously were here and are used by other modules
|
# imports for functions that previously were here and are used by other modules
|
||||||
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image
|
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401
|
||||||
|
|
||||||
all_samplers = [
|
all_samplers = [
|
||||||
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
*sd_samplers_kdiffusion.samplers_data_k_diffusion,
|
||||||
@ -14,12 +14,18 @@ samplers_for_img2img = []
|
|||||||
samplers_map = {}
|
samplers_map = {}
|
||||||
|
|
||||||
|
|
||||||
def create_sampler(name, model):
|
def find_sampler_config(name):
|
||||||
if name is not None:
|
if name is not None:
|
||||||
config = all_samplers_map.get(name, None)
|
config = all_samplers_map.get(name, None)
|
||||||
else:
|
else:
|
||||||
config = all_samplers[0]
|
config = all_samplers[0]
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def create_sampler(name, model):
|
||||||
|
config = find_sampler_config(name)
|
||||||
|
|
||||||
assert config is not None, f'bad sampler name: {name}'
|
assert config is not None, f'bad sampler name: {name}'
|
||||||
|
|
||||||
sampler = config.constructor(model)
|
sampler = config.constructor(model)
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user