remove excess condition

This commit is contained in:
papuSpartan 2023-04-01 23:47:10 -05:00
parent a609bd56b4
commit c707b7df95

View File

@ -501,26 +501,26 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae': if k == 'sd_vae':
sd_vae.reload_vae_weights() sd_vae.reload_vae_weights()
if opts.token_merging: if opts.token_merging and not opts.token_merging_hr_only:
print("applying token merging to all passes")
if p.hr_second_pass_steps < 1 and not opts.token_merging_hr_only: tomesd.apply_patch(
tomesd.apply_patch( p.sd_model,
p.sd_model, ratio=opts.token_merging_ratio,
ratio=opts.token_merging_ratio, max_downsample=opts.token_merging_maximum_down_sampling,
max_downsample=opts.token_merging_maximum_down_sampling, sx=opts.token_merging_stride_x,
sx=opts.token_merging_stride_x, sy=opts.token_merging_stride_y,
sy=opts.token_merging_stride_y, use_rand=opts.token_merging_random,
use_rand=opts.token_merging_random, merge_attn=opts.token_merging_merge_attention,
merge_attn=opts.token_merging_merge_attention, merge_crossattn=opts.token_merging_merge_cross_attention,
merge_crossattn=opts.token_merging_merge_cross_attention, merge_mlp=opts.token_merging_merge_mlp
merge_mlp=opts.token_merging_merge_mlp )
)
res = process_images_inner(p) res = process_images_inner(p)
finally: finally:
# undo model optimizations made by tomesd # undo model optimizations made by tomesd
if opts.token_merging: if opts.token_merging:
print('removing token merging model optimizations')
tomesd.remove_patch(p.sd_model) tomesd.remove_patch(p.sd_model)
# restore opts to original state # restore opts to original state
@ -961,6 +961,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
# apply token merging optimizations from tomesd for high-res pass # apply token merging optimizations from tomesd for high-res pass
# check if hr_only so we don't redundantly apply patch # check if hr_only so we don't redundantly apply patch
if opts.token_merging and opts.token_merging_hr_only: if opts.token_merging and opts.token_merging_hr_only:
print("applying token merging for high-res pass")
tomesd.apply_patch( tomesd.apply_patch(
self.sd_model, self.sd_model,
ratio=opts.token_merging_ratio, ratio=opts.token_merging_ratio,