Allow a different merge ratio to be used for each pass. Make the toggle cmd flag work again. Remove the ratio flag. Remove the warning about the controlnet extension being incompatible.

This commit is contained in:
papuSpartan 2023-04-04 02:26:44 -05:00
parent c707b7df95
commit 5c8e53d5e9
4 changed files with 49 additions and 33 deletions

View File

@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
# token merging / tomesd
parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False)
parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)
parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)

View File

@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
if opts.token_merging and not opts.token_merging_hr_only:
print("applying token merging to all passes")
tomesd.apply_patch(
p.sd_model,
ratio=opts.token_merging_ratio,
max_downsample=opts.token_merging_maximum_down_sampling,
sx=opts.token_merging_stride_x,
sy=opts.token_merging_stride_y,
use_rand=opts.token_merging_random,
merge_attn=opts.token_merging_merge_attention,
merge_crossattn=opts.token_merging_merge_cross_attention,
merge_mlp=opts.token_merging_merge_mlp
)
if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
print("\nApplying token merging\n")
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
res = process_images_inner(p)
finally:
# undo model optimizations made by tomesd
if opts.token_merging:
print('removing token merging model optimizations')
if opts.token_merging or cmd_opts.token_merging:
print('\nRemoving token merging model optimizations\n')
tomesd.remove_patch(p.sd_model)
# restore opts to original state
@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc()
# apply token merging optimizations from tomesd for high-res pass
# check if hr_only so we don't redundantly apply patch
if opts.token_merging and opts.token_merging_hr_only:
print("applying token merging for high-res pass")
tomesd.apply_patch(
self.sd_model,
ratio=opts.token_merging_ratio,
max_downsample=opts.token_merging_maximum_down_sampling,
sx=opts.token_merging_stride_x,
sy=opts.token_merging_stride_y,
use_rand=opts.token_merging_random,
merge_attn=opts.token_merging_merge_attention,
merge_crossattn=opts.token_merging_merge_cross_attention,
merge_mlp=opts.token_merging_merge_mlp
)
# check if hr_only so we are not redundantly patching
if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
# case where user wants to use separate merge ratios
if not opts.token_merging_hr_only:
# clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
print('Temporarily reverting token merging optimizations in preparation for next pass')
tomesd.remove_patch(self.sd_model)
print("\nApplying token merging for high-res pass\n")
sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

View File

@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@ -546,3 +547,29 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.")
return sd_model
def apply_token_merging(sd_model, hr: bool):
    """
    Applies speed and memory optimizations from tomesd by patching the model in place.

    Args:
        sd_model: the loaded Stable Diffusion model to patch (modified in place)
        hr (bool): True if called in the context of a high-res pass
    """
    # The high-res pass may use its own merge ratio (token_merging_ratio_hr);
    # all other passes use the base ratio.
    ratio = shared.opts.token_merging_ratio_hr if hr else shared.opts.token_merging_ratio
    if hr:
        print(f"effective hr pass merge ratio is {ratio}")

    tomesd.apply_patch(
        sd_model,
        ratio=ratio,
        max_downsample=shared.opts.token_merging_maximum_down_sampling,
        sx=shared.opts.token_merging_stride_x,
        sy=shared.opts.token_merging_stride_y,
        use_rand=shared.opts.token_merging_random,
        merge_attn=shared.opts.token_merging_merge_attention,
        merge_crossattn=shared.opts.token_merging_merge_cross_attention,
        merge_mlp=shared.opts.token_merging_merge_mlp
    )

View File

@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {
options_templates.update(options_section(('token_merging', 'Token Merging'), {
"token_merging": OptionInfo(
False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
gr.Checkbox
),
"token_merging_ratio": OptionInfo(
@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
gr.Checkbox
),
"token_merging_ratio_hr": OptionInfo(
0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
),
# More advanced/niche settings:
"token_merging_random": OptionInfo(
True, "Use random perturbations - Disabling might help with certain samplers",