Allow a different merge ratio to be used for each pass. Make the command-line toggle flag work again. Remove the ratio flag. Remove the warning about ControlNet being incompatible.
This commit is contained in:
parent
c707b7df95
commit
5c8e53d5e9
@ -103,5 +103,4 @@ parser.add_argument("--no-hashing", action='store_true', help="disable sha256 ha
|
|||||||
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
|
||||||
|
|
||||||
# token merging / tomesd
|
# token merging / tomesd
|
||||||
parser.add_argument("--token-merging", action='store_true', help="Provides generation speedup by merging redundant tokens. (compatible with --xformers)", default=False)
|
parser.add_argument("--token-merging", action='store_true', help="Provides speed and memory improvements by merging redundant tokens. This has a more pronounced effect on higher resolutions.", default=False)
|
||||||
parser.add_argument("--token-merging-ratio", type=float, help="Adjusts ratio of merged to untouched tokens. Range: (0.0-1.0], Defaults to 0.5", default=0.5)
|
|
||||||
|
@ -501,26 +501,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
|
|||||||
if k == 'sd_vae':
|
if k == 'sd_vae':
|
||||||
sd_vae.reload_vae_weights()
|
sd_vae.reload_vae_weights()
|
||||||
|
|
||||||
if opts.token_merging and not opts.token_merging_hr_only:
|
if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only:
|
||||||
print("applying token merging to all passes")
|
print("\nApplying token merging\n")
|
||||||
tomesd.apply_patch(
|
sd_models.apply_token_merging(sd_model=p.sd_model, hr=False)
|
||||||
p.sd_model,
|
|
||||||
ratio=opts.token_merging_ratio,
|
|
||||||
max_downsample=opts.token_merging_maximum_down_sampling,
|
|
||||||
sx=opts.token_merging_stride_x,
|
|
||||||
sy=opts.token_merging_stride_y,
|
|
||||||
use_rand=opts.token_merging_random,
|
|
||||||
merge_attn=opts.token_merging_merge_attention,
|
|
||||||
merge_crossattn=opts.token_merging_merge_cross_attention,
|
|
||||||
merge_mlp=opts.token_merging_merge_mlp
|
|
||||||
)
|
|
||||||
|
|
||||||
res = process_images_inner(p)
|
res = process_images_inner(p)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# undo model optimizations made by tomesd
|
# undo model optimizations made by tomesd
|
||||||
if opts.token_merging:
|
if opts.token_merging or cmd_opts.token_merging:
|
||||||
print('removing token merging model optimizations')
|
print('\nRemoving token merging model optimizations\n')
|
||||||
tomesd.remove_patch(p.sd_model)
|
tomesd.remove_patch(p.sd_model)
|
||||||
|
|
||||||
# restore opts to original state
|
# restore opts to original state
|
||||||
@ -959,20 +949,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
|
|||||||
devices.torch_gc()
|
devices.torch_gc()
|
||||||
|
|
||||||
# apply token merging optimizations from tomesd for high-res pass
|
# apply token merging optimizations from tomesd for high-res pass
|
||||||
# check if hr_only so we don't redundantly apply patch
|
# check if hr_only so we are not redundantly patching
|
||||||
if opts.token_merging and opts.token_merging_hr_only:
|
if (cmd_opts.token_merging or opts.token_merging) and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio):
|
||||||
print("applying token merging for high-res pass")
|
# case where user wants to use separate merge ratios
|
||||||
tomesd.apply_patch(
|
if not opts.token_merging_hr_only:
|
||||||
self.sd_model,
|
# clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive)
|
||||||
ratio=opts.token_merging_ratio,
|
print('Temporarily reverting token merging optimizations in preparation for next pass')
|
||||||
max_downsample=opts.token_merging_maximum_down_sampling,
|
tomesd.remove_patch(self.sd_model)
|
||||||
sx=opts.token_merging_stride_x,
|
|
||||||
sy=opts.token_merging_stride_y,
|
print("\nApplying token merging for high-res pass\n")
|
||||||
use_rand=opts.token_merging_random,
|
sd_models.apply_token_merging(sd_model=self.sd_model, hr=True)
|
||||||
merge_attn=opts.token_merging_merge_attention,
|
|
||||||
merge_crossattn=opts.token_merging_merge_cross_attention,
|
|
||||||
merge_mlp=opts.token_merging_merge_mlp
|
|
||||||
)
|
|
||||||
|
|
||||||
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
|
||||||
|
|
||||||
|
@ -16,6 +16,7 @@ from modules import paths, shared, modelloader, devices, script_callbacks, sd_va
|
|||||||
from modules.paths import models_path
|
from modules.paths import models_path
|
||||||
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
from modules.sd_hijack_inpainting import do_inpainting_hijack
|
||||||
from modules.timer import Timer
|
from modules.timer import Timer
|
||||||
|
import tomesd
|
||||||
|
|
||||||
model_dir = "Stable-diffusion"
|
model_dir = "Stable-diffusion"
|
||||||
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
|
||||||
@ -545,4 +546,30 @@ def unload_model_weights(sd_model=None, info=None):
|
|||||||
|
|
||||||
print(f"Unloaded weights {timer.summary()}.")
|
print(f"Unloaded weights {timer.summary()}.")
|
||||||
|
|
||||||
return sd_model
|
return sd_model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_token_merging(sd_model, hr: bool):
    """
    Patch ``sd_model`` with the tomesd token-merging optimizations, configured
    from the ``token_merging_*`` values in ``shared.opts``.

    Args:
        sd_model: the model object handed to ``tomesd.apply_patch``
        hr (bool): True if called in the context of a high-res pass; in that
            case ``token_merging_ratio_hr`` is used (and printed) instead of
            ``token_merging_ratio``
    """
    settings = shared.opts

    # The regular ratio is always read first; the high-res pass overrides it.
    merge_ratio = settings.token_merging_ratio
    if hr:
        merge_ratio = settings.token_merging_ratio_hr
        print(f"effective hr pass merge ratio is {merge_ratio!s}")

    # Collect every tomesd knob in one place before patching the model.
    patch_kwargs = {
        "ratio": merge_ratio,
        "max_downsample": settings.token_merging_maximum_down_sampling,
        "sx": settings.token_merging_stride_x,
        "sy": settings.token_merging_stride_y,
        "use_rand": settings.token_merging_random,
        "merge_attn": settings.token_merging_merge_attention,
        "merge_crossattn": settings.token_merging_merge_cross_attention,
        "merge_mlp": settings.token_merging_merge_mlp,
    }
    tomesd.apply_patch(sd_model, **patch_kwargs)
|
||||||
|
@ -429,7 +429,7 @@ options_templates.update(options_section((None, "Hidden options"), {
|
|||||||
|
|
||||||
options_templates.update(options_section(('token_merging', 'Token Merging'), {
|
options_templates.update(options_section(('token_merging', 'Token Merging'), {
|
||||||
"token_merging": OptionInfo(
|
"token_merging": OptionInfo(
|
||||||
False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)",
|
0.5, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.",
|
||||||
gr.Checkbox
|
gr.Checkbox
|
||||||
),
|
),
|
||||||
"token_merging_ratio": OptionInfo(
|
"token_merging_ratio": OptionInfo(
|
||||||
@ -440,6 +440,10 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), {
|
|||||||
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
|
True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.",
|
||||||
gr.Checkbox
|
gr.Checkbox
|
||||||
),
|
),
|
||||||
|
"token_merging_ratio_hr": OptionInfo(
|
||||||
|
0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.",
|
||||||
|
gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1}
|
||||||
|
),
|
||||||
# More advanced/niche settings:
|
# More advanced/niche settings:
|
||||||
"token_merging_random": OptionInfo(
|
"token_merging_random": OptionInfo(
|
||||||
True, "Use random perturbations - Disabling might help with certain samplers",
|
True, "Use random perturbations - Disabling might help with certain samplers",
|
||||||
|
Loading…
Reference in New Issue
Block a user