Allow different merge ratios to be used for each pass. Make toggle cmd flag work again. Remove ratio flag. Remove warning about controlnet being incompatible

This commit is contained in:
papuSpartan 2023-04-04 02:26:44 -05:00
parent c707b7df95
commit 5c8e53d5e9
4 changed files with 49 additions and 33 deletions

View File

@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False) parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
# token merging / tomesd # token merging / tomesd
parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False) parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)

View File

@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae': if k == 'sd_vae':
sd_vae.reload_vae_weights() sd_vae.reload_vae_weights()
if opts.token_merging and not opts.token_merging_hr_only: if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
print("applying token merging to all passes") print("\nApplying token merging\n")
tomesd.apply_patch( sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
p.sd_model,
ratio=opts.token_merging_ratio,
max_downsample=opts.token_merging_maximum_down_sampling,
sx=opts.token_merging_stride_x,
sy=opts.token_merging_stride_y,
use_rand=opts.token_merging_random,
merge_attn=opts.token_merging_merge_attention,
merge_crossattn=opts.token_merging_merge_cross_attention,
merge_mlp=opts.token_merging_merge_mlp
)
res = process_images_inner(p) res = process_images_inner(p)
finally: finally:
# undo model optimizations made by tomesd # undo model optimizations made by tomesd
if opts.token_merging: if opts.token_merging or cmd_opts.token_merging:
print('removing token merging model optimizations') print('\nRemoving token merging model optimizations\n')
tomesd.remove_patch(p.sd_model) tomesd.remove_patch(p.sd_model)
# restore opts to original state # restore opts to original state
@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
devices.torch_gc() devices.torch_gc()
# apply token merging optimizations from tomesd for high-res pass # apply token merging optimizations from tomesd for high-res pass
# check if hr_only so we don't redundantly apply patch # check if hr_only so we are not redundantly patching
if opts.token_merging and opts.token_merging_hr_only: if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
print("applying token merging for high-res pass") # case where user wants to use separate merge ratios
tomesd.apply_patch( if not opts.token_merging_hr_only:
self.sd_model, # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
ratio=opts.token_merging_ratio, print('Temporarily reverting token merging optimizations in preparation for next pass')
max_downsample=opts.token_merging_maximum_down_sampling, tomesd.remove_patch(self.sd_model)
sx=opts.token_merging_stride_x,
sy=opts.token_merging_stride_y, print("\nApplying token merging for high-res pass\n")
use_rand=opts.token_merging_random, sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
merge_attn=opts.token_merging_merge_attention,
merge_crossattn=opts.token_merging_merge_cross_attention,
merge_mlp=opts.token_merging_merge_mlp
)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

View File

@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd
model_dir = "Stable-diffusion" model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@ -545,4 +546,30 @@ def unload_model_weights(sd_model=None, info=None):
print(f"Unloaded weights {timer.summary()}.") print(f"Unloaded weights {timer.summary()}.")
return sd_model return sd_model
def apply_token_merging(sd_model, hr: bool):
    """
    Applies speed and memory optimizations from tomesd by patching the model in place.

    Args:
        sd_model: the loaded Stable Diffusion model to patch (modified in place)
        hr (bool): True if called in the context of a high-res pass; selects the
            high-res-specific merge ratio instead of the first-pass one
    """
    # The high-res pass may use its own merge ratio so the two passes can be
    # tuned independently; all other tomesd options are shared between passes.
    ratio = shared.opts.token_merging_ratio
    if hr:
        ratio = shared.opts.token_merging_ratio_hr
        print(f"effective hr pass merge ratio is {ratio}")

    tomesd.apply_patch(
        sd_model,
        ratio=ratio,
        max_downsample=shared.opts.token_merging_maximum_down_sampling,
        sx=shared.opts.token_merging_stride_x,
        sy=shared.opts.token_merging_stride_y,
        use_rand=shared.opts.token_merging_random,
        merge_attn=shared.opts.token_merging_merge_attention,
        merge_crossattn=shared.opts.token_merging_merge_cross_attention,
        merge_mlp=shared.opts.token_merging_merge_mlp
    )

View File

@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {
options_templates.update(options_section(('token_merging', 'Token Merging'), { options_templates.update(options_section(('token_merging', 'Token Merging'), {
"token_merging": OptionInfo( "token_merging": OptionInfo(
False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)", 0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
gr.Checkbox gr.Checkbox
), ),
"token_merging_ratio": OptionInfo( "token_merging_ratio": OptionInfo(
@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.", True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
gr.Checkbox gr.Checkbox
), ),
"token_merging_ratio_hr": OptionInfo(
0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
),
# More advanced/niche settings: # More advanced/niche settings:
"token_merging_random": OptionInfo( "token_merging_random": OptionInfo(
True, "Use random perturbations - Disabling might help with certain samplers", True, "Use random perturbations - Disabling might help with certain samplers",