diff --git a/launch.py b/launch.py index 846c4c20..68e08114 100644 --- a/launch.py +++ b/launch.py @@ -280,9 +280,6 @@ def prepare_environment(): elif platform.system() == "Linux": run_pip(f"install {xformers_package}", "xformers") - if not is_installed("tomesd") and args.token_merging: - run_pip(f"install tomesd") - if not is_installed("pyngrok") and args.ngrok: run_pip("install pyngrok", "ngrok") diff --git a/modules/processing.py b/modules/processing.py index 6d9c6a8d..e115aadd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -29,6 +29,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType +import tomesd # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -500,9 +501,28 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_vae': sd_vae.reload_vae_weights() + if opts.token_merging: + + if p.hr_second_pass_steps < 1 and not opts.token_merging_hr_only: + tomesd.apply_patch( + p.sd_model, + ratio=opts.token_merging_ratio, + max_downsample=opts.token_merging_maximum_down_sampling, + sx=opts.token_merging_stride_x, + sy=opts.token_merging_stride_y, + use_rand=opts.token_merging_random, + merge_attn=opts.token_merging_merge_attention, + merge_crossattn=opts.token_merging_merge_cross_attention, + merge_mlp=opts.token_merging_merge_mlp + ) + res = process_images_inner(p) finally: + # undo model optimizations made by tomesd + if opts.token_merging: + tomesd.remove_patch(p.sd_model) + # restore opts to original state if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): @@ -938,6 +958,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None devices.torch_gc() + # apply token merging optimizations from tomesd for high-res pass + # check if hr_only so we don't redundantly apply patch + if opts.token_merging and opts.token_merging_hr_only: + tomesd.apply_patch( + self.sd_model, + ratio=opts.token_merging_ratio, + max_downsample=opts.token_merging_maximum_down_sampling, + sx=opts.token_merging_stride_x, + sy=opts.token_merging_stride_y, + use_rand=opts.token_merging_random, + merge_attn=opts.token_merging_merge_attention, + merge_crossattn=opts.token_merging_merge_cross_attention, + merge_mlp=opts.token_merging_merge_mlp + ) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) return samples diff --git a/modules/sd_models.py b/modules/sd_models.py index 2c05ec17..87c49b83 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -431,13 +431,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): sd_model = instantiate_from_config(sd_config.model) - if shared.cmd_opts.token_merging: - import tomesd - ratio = shared.cmd_opts.token_merging_ratio - - tomesd.apply_patch(sd_model, ratio=ratio) - print(f"Model accelerated using {(ratio * 100)}% token merging via tomesd.") - timer.record("token merging") except Exception as e: pass diff --git a/modules/shared.py b/modules/shared.py index 5fd0eecb..d7379e24 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -427,6 +427,50 @@ options_templates.update(options_section((None, "Hidden options"), { "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) +options_templates.update(options_section(('token_merging', 'Token Merging'), { + "token_merging": OptionInfo( + False, "Enable redundant token merging via tomesd. (currently incompatible with controlnet extension)", + gr.Checkbox + ), + "token_merging_ratio": OptionInfo( + 0.5, "Merging Ratio", + gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1} + ), + "token_merging_hr_only": OptionInfo( + True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.", + gr.Checkbox + ), + # More advanced/niche settings: + "token_merging_random": OptionInfo( + True, "Use random perturbations - Disabling might help with certain samplers", + gr.Checkbox + ), + "token_merging_merge_attention": OptionInfo( + True, "Merge attention", + gr.Checkbox + ), + "token_merging_merge_cross_attention": OptionInfo( + False, "Merge cross attention", + gr.Checkbox + ), + "token_merging_merge_mlp": OptionInfo( + False, "Merge mlp", + gr.Checkbox + ), + "token_merging_maximum_down_sampling": OptionInfo( + 1, "Maximum down sampling", + gr.Dropdown, lambda: {"choices": ["1", "2", "4", "8"]} + ), + "token_merging_stride_x": OptionInfo( + 2, "Stride - X", + gr.Slider, {"minimum": 2, "maximum": 8, "step": 2} + ), + "token_merging_stride_y": OptionInfo( + 2, "Stride - Y", + gr.Slider, {"minimum": 2, "maximum": 8, "step": 2} + ) +})) + options_templates.update() diff --git a/requirements_versions.txt b/requirements_versions.txt index df65431a..045230ab 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -28,3 +28,4 @@ torchsde==0.2.5 safetensors==0.3.0 httpcore<=0.15 fastapi==0.94.0 +tomesd>=0.1 \ No newline at end of file