diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 6df76858..a7ede534 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -282,6 +282,32 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model res["Hires resize-1"] = 0 res["Hires resize-2"] = 0 + # Infer additional override settings for token merging + print("inferring settings for tomesd") + token_merging_ratio = res.get("Token merging ratio", None) + token_merging_ratio_hr = res.get("Token merging ratio hr", None) + + if token_merging_ratio is not None or token_merging_ratio_hr is not None: + res["Token merging"] = 'True' + + if token_merging_ratio is None: + res["Token merging hr only"] = 'True' + else: + res["Token merging hr only"] = 'False' + + if res.get("Token merging random", None) is None: + res["Token merging random"] = 'False' + if res.get("Token merging merge attention", None) is None: + res["Token merging merge attention"] = 'True' + if res.get("Token merging merge cross attention", None) is None: + res["Token merging merge cross attention"] = 'False' + if res.get("Token merging merge mlp", None) is None: + res["Token merging merge mlp"] = 'False' + if res.get("Token merging stride x", None) is None: + res["Token merging stride x"] = '2' + if res.get("Token merging stride y", None) is None: + res["Token merging stride y"] = '2' + restore_old_hires_fix_params(res) return res @@ -304,6 +330,17 @@ infotext_to_setting_name_mapping = [ ('UniPC skip type', 'uni_pc_skip_type'), ('UniPC order', 'uni_pc_order'), ('UniPC lower order final', 'uni_pc_lower_order_final'), + ('Token merging', 'token_merging'), + ('Token merging ratio', 'token_merging_ratio'), + ('Token merging hr only', 'token_merging_hr_only'), + ('Token merging ratio hr', 'token_merging_ratio_hr'), + ('Token merging random', 'token_merging_random'), + ('Token merging merge attention', 'token_merging_merge_attention'), + ('Token merging merge cross attention', 'token_merging_merge_cross_attention'), + ('Token merging merge mlp', 'token_merging_merge_mlp'), + ('Token merging maximum downsampling', 'token_merging_maximum_downsampling'), + ('Token merging stride x', 'token_merging_stride_x'), + ('Token merging stride y', 'token_merging_stride_y') ] diff --git a/modules/processing.py b/modules/processing.py index 670a7a28..95058d0b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -31,6 +31,12 @@ from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType import tomesd +# add a logger for the processing module +logger = logging.getLogger(__name__) +# manually set output level here since there is no option to do so yet through launch options +# logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(levelname)s %(name)s %(message)s') + + # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 opt_f = 8 @@ -477,6 +483,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, + "Token merging ratio": None if not (opts.token_merging or cmd_opts.token_merging) or opts.token_merging_hr_only else opts.token_merging_ratio, + "Token merging ratio hr": None if not (opts.token_merging or cmd_opts.token_merging) else opts.token_merging_ratio_hr, + "Token merging random": None if opts.token_merging_random is False else opts.token_merging_random, + "Token merging merge attention": None if opts.token_merging_merge_attention is True else opts.token_merging_merge_attention, + "Token merging merge cross attention": None if opts.token_merging_merge_cross_attention is False else opts.token_merging_merge_cross_attention, + "Token merging merge mlp": None if opts.token_merging_merge_mlp is False else opts.token_merging_merge_mlp, + "Token merging stride x": None if opts.token_merging_stride_x == 2 else opts.token_merging_stride_x, + "Token merging stride y": None if opts.token_merging_stride_y == 2 else opts.token_merging_stride_y } generation_params.update(p.extra_generation_params) @@ -502,16 +516,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed: sd_vae.reload_vae_weights() if (opts.token_merging or cmd_opts.token_merging) and not opts.token_merging_hr_only: - print("\nApplying token merging\n") sd_models.apply_token_merging(sd_model=p.sd_model, hr=False) + logger.debug('Token merging applied') res = process_images_inner(p) finally: # undo model optimizations made by tomesd if opts.token_merging or cmd_opts.token_merging: - print('\nRemoving token merging model optimizations\n') tomesd.remove_patch(p.sd_model) + logger.debug('Token merging model optimizations removed') # restore opts to original state if p.override_settings_restore_afterwards: @@ -954,11 +968,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): # case where user wants to use separate merge ratios if not opts.token_merging_hr_only: # clean patch done by first pass. (clobbering the first patch might be fine? this might be excessive) - print('Temporarily reverting token merging optimizations in preparation for next pass') tomesd.remove_patch(self.sd_model) + logger.debug('Temporarily removed token merging optimizations in preparation for next pass') - print("\nApplying token merging for high-res pass\n") sd_models.apply_token_merging(sd_model=self.sd_model, hr=True) + logger.debug('Applied token merging for high-res pass') samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) diff --git a/modules/shared.py b/modules/shared.py index 568acdc4..d9db7317 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -446,7 +446,7 @@ options_templates.update(options_section(('token_merging', 'Token Merging'), { ), # More advanced/niche settings: "token_merging_random": OptionInfo( - True, "Use random perturbations - Disabling might help with certain samplers", + False, "Use random perturbations - Can improve outputs for certain samplers. For others, it may cause visaul artifacting.", gr.Checkbox ), "token_merging_merge_attention": OptionInfo( diff --git a/requirements_versions.txt b/requirements_versions.txt index 03522715..f972f975 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -28,4 +28,4 @@ torchsde==0.2.5 safetensors==0.3.0 httpcore<=0.15 fastapi==0.94.0 -tomesd>=0.1.1 \ No newline at end of file +tomesd>=0.1.2 \ No newline at end of file