From ac83627a31daac06f4d48b0e7db223ef807fe8e5 Mon Sep 17 00:00:00 2001 From: papuSpartan <30642826+papuSpartan@users.noreply.github.com> Date: Sat, 13 May 2023 10:23:42 -0500 Subject: [PATCH] heavily simplify --- modules/generation_parameters_copypaste.py | 36 ------------------- modules/processing.py | 35 ++++++++---------- modules/sd_models.py | 11 +++--- modules/shared.py | 42 +++------------------- 4 files changed, 23 insertions(+), 101 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index fb56254f..a0a98bbc 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -282,33 +282,6 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model res["Hires resize-1"] = 0 res["Hires resize-2"] = 0 - # Infer additional override settings for token merging - token_merging_ratio = res.get("Token merging ratio", None) - token_merging_ratio_hr = res.get("Token merging ratio hr", None) - - if token_merging_ratio is not None or token_merging_ratio_hr is not None: - res["Token merging"] = 'True' - - if token_merging_ratio is None: - res["Token merging hr only"] = 'True' - else: - res["Token merging hr only"] = 'False' - - if res.get("Token merging random", None) is None: - res["Token merging random"] = 'False' - if res.get("Token merging merge attention", None) is None: - res["Token merging merge attention"] = 'True' - if res.get("Token merging merge cross attention", None) is None: - res["Token merging merge cross attention"] = 'False' - if res.get("Token merging merge mlp", None) is None: - res["Token merging merge mlp"] = 'False' - if res.get("Token merging stride x", None) is None: - res["Token merging stride x"] = '2' - if res.get("Token merging stride y", None) is None: - res["Token merging stride y"] = '2' - if res.get("Token merging maximum down sampling", None) is None: - res["Token merging maximum down sampling"] = '1' - restore_old_hires_fix_params(res) # Missing RNG means the default was set, which is GPU RNG @@ -335,17 +308,8 @@ infotext_to_setting_name_mapping = [ ('UniPC skip type', 'uni_pc_skip_type'), ('UniPC order', 'uni_pc_order'), ('UniPC lower order final', 'uni_pc_lower_order_final'), - ('Token merging', 'token_merging'), ('Token merging ratio', 'token_merging_ratio'), - ('Token merging hr only', 'token_merging_hr_only'), ('Token merging ratio hr', 'token_merging_ratio_hr'), - ('Token merging random', 'token_merging_random'), - ('Token merging merge attention', 'token_merging_merge_attention'), - ('Token merging merge cross attention', 'token_merging_merge_cross_attention'), - ('Token merging merge mlp', 'token_merging_merge_mlp'), - ('Token merging maximum down sampling', 'token_merging_maximum_down_sampling'), - ('Token merging stride x', 'token_merging_stride_x'), - ('Token merging stride y', 'token_merging_stride_y'), ('RNG', 'randn_source'), ('NGMS', 's_min_uncond') ] diff --git a/modules/processing.py b/modules/processing.py index 6828e898..32ff61e9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -34,7 +34,7 @@ import tomesd # add a logger for the processing module logger = logging.getLogger(__name__) # manually set output level here since there is no option to do so yet through launch options -# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s') +logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s') # some of those options should not be changed at all because they would break the model, so I removed them from options. @@ -496,15 +496,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, - "Token merging ratio": None if not opts.token_merging or opts.token_merging_hr_only else opts.token_merging_ratio, - "Token merging ratio hr": None if not opts.token_merging else opts.token_merging_ratio_hr, - "Token merging random": None if opts.token_merging_random is False else opts.token_merging_random, - "Token merging merge attention": None if opts.token_merging_merge_attention is True else opts.token_merging_merge_attention, - "Token merging merge cross attention": None if opts.token_merging_merge_cross_attention is False else opts.token_merging_merge_cross_attention, - "Token merging merge mlp": None if opts.token_merging_merge_mlp is False else opts.token_merging_merge_mlp, - "Token merging stride x": None if opts.token_merging_stride_x == 2 else opts.token_merging_stride_x, - "Token merging stride y": None if opts.token_merging_stride_y == 2 else opts.token_merging_stride_y, - "Token merging maximum down sampling": None if opts.token_merging_maximum_down_sampling == 1 else opts.token_merging_maximum_down_sampling, + "Token merging ratio": None if opts.token_merging_ratio == 0 else opts.token_merging_ratio, + "Token merging ratio hr": None if not p.enable_hr or opts.token_merging_ratio_hr == 0 else opts.token_merging_ratio_hr, "Init image hash": getattr(p, 'init_img_hash', None), "RNG": opts.randn_source if opts.randn_source != "GPU" else None, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, @@ -538,15 +531,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_vae': sd_vae.reload_vae_weights() - if opts.token_merging and not opts.token_merging_hr_only: + if opts.token_merging_ratio > 0: sd_models.apply_token_merging(sd_model=p.sd_model, hr=False) - logger.debug('Token merging applied') + logger.debug(f"Token merging applied to first pass. Ratio: '{opts.token_merging_ratio}'") res = process_images_inner(p) finally: # undo model optimizations made by tomesd - if opts.token_merging: + if opts.token_merging_ratio > 0: tomesd.remove_patch(p.sd_model) logger.debug('Token merging model optimizations removed') @@ -1003,19 +996,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): devices.torch_gc() # apply token merging optimizations from tomesd for high-res pass - # check if hr_only so we are not redundantly patching - if opts.token_merging and (opts.token_merging_hr_only or opts.token_merging_ratio_hr != opts.token_merging_ratio): - # case where user wants to use separate merge ratios - if not opts.token_merging_hr_only: - # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive) + if opts.token_merging_ratio_hr > 0: + # in case the user has used separate merge ratios + if opts.token_merging_ratio > 0: tomesd.remove_patch(self.sd_model) - logger.debug('Temporarily removed token merging optimizations in preparation for next pass') + logger.debug('Adjusting token merging ratio for high-res pass') sd_models.apply_token_merging(sd_model=self.sd_model, hr=True) - logger.debug('Applied token merging for high-res pass') + logger.debug(f"Applied token merging for high-res pass. Ratio: '{opts.token_merging_ratio_hr}'") samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) + if opts.token_merging_ratio_hr > 0 or opts.token_merging_ratio > 0: + tomesd.remove_patch(self.sd_model) + logger.debug('Removed token merging optimizations from model') + self.is_hr_pass = False return samples diff --git a/modules/sd_models.py b/modules/sd_models.py index 4787193c..4c9a0a1f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -596,11 +596,8 @@ def apply_token_merging(sd_model, hr: bool): tomesd.apply_patch( sd_model, ratio=ratio, - max_downsample=shared.opts.token_merging_maximum_down_sampling, - sx=shared.opts.token_merging_stride_x, - sy=shared.opts.token_merging_stride_y, - use_rand=shared.opts.token_merging_random, - merge_attn=shared.opts.token_merging_merge_attention, - merge_crossattn=shared.opts.token_merging_merge_cross_attention, - merge_mlp=shared.opts.token_merging_merge_mlp + use_rand=False, # can cause issues with some samplers + merge_attn=True, + merge_crossattn=False, + merge_mlp=False ) diff --git a/modules/shared.py b/modules/shared.py index 4b346585..0d96c14c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -459,47 +459,13 @@ options_templates.update(options_section((None, "Hidden options"), { })) options_templates.update(options_section(('token_merging', 'Token Merging'), { - "token_merging": OptionInfo( - False, "Enable redundant token merging via tomesd. This can provide significant speed and memory improvements.", - gr.Checkbox + "token_merging_ratio_hr": OptionInfo( + 0, "Merging Ratio (high-res pass)", + gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1} ), "token_merging_ratio": OptionInfo( - 0.5, "Merging Ratio", + 0, "Merging Ratio", gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1} - ), - "token_merging_hr_only": OptionInfo( - True, "Apply only to high-res fix pass. Disabling can yield a ~20-35% speedup on contemporary resolutions.", - gr.Checkbox - ), - "token_merging_ratio_hr": OptionInfo( - 0.5, "Merging Ratio (high-res pass) - If 'Apply only to high-res' is enabled, this will always be the ratio used.", - gr.Slider, {"minimum": 0, "maximum": 0.9, "step": 0.1} - ), - # More advanced/niche settings: - "token_merging_random": OptionInfo( - False, "Use random perturbations - Can improve outputs for certain samplers. For others, it may cause visual artifacting.", - gr.Checkbox - ), - "token_merging_merge_attention": OptionInfo( - True, "Merge attention", - gr.Checkbox - ), - "token_merging_merge_cross_attention": OptionInfo( - False, "Merge cross attention", - gr.Checkbox - ), - "token_merging_merge_mlp": OptionInfo( - False, "Merge mlp", - gr.Checkbox - ), - "token_merging_maximum_down_sampling": OptionInfo(1, "Maximum down sampling", gr.Radio, lambda: {"choices": [1, 2, 4, 8]}), - "token_merging_stride_x": OptionInfo( - 2, "Stride - X", - gr.Slider, {"minimum": 2, "maximum": 8, "step": 2} - ), - "token_merging_stride_y": OptionInfo( - 2, "Stride - Y", - gr.Slider, {"minimum": 2, "maximum": 8, "step": 2} ) }))