diff --git a/modules/extra_networks.py b/modules/extra_networks.py index f4743928..1f093df2 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -32,6 +32,9 @@ class ExtraNetworkParams: else: self.positional.append(item) + def __eq__(self, other): + return self.items == other.items + class ExtraNetwork: def __init__(self, name): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..fae83788 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -171,6 +171,7 @@ class StableDiffusionProcessing: self.prompts = None self.negative_prompts = None + self.extra_network_data = None self.seeds = None self.subseeds = None @@ -311,7 +312,7 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] - def get_conds_with_caching(self, function, required_prompts, steps, cache): + def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -321,26 +322,24 @@ class StableDiffusionProcessing: have been used before. The second element is where the previously computed result is stored. """ - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: + if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: return cache[1] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) + cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) return cache[1] def setup_conds(self): sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data) def parse_extra_network_prompts(self): - self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) - - return extra_network_data + self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) class Processed: @@ -681,7 +680,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - extra_network_data = None for n in range(p.n_iter): p.iteration = n @@ -702,11 +700,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(p.prompts) == 0: break - extra_network_data = p.parse_extra_network_prompts() + p.parse_extra_network_prompts() if not p.disable_extra_networks: with devices.autocast(): - extra_networks.activate(p, extra_network_data) + extra_networks.activate(p, p.extra_network_data) if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -828,8 +826,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - if not p.disable_extra_networks and extra_network_data: - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks and p.extra_network_data: + extra_networks.deactivate(p, p.extra_network_data) devices.torch_gc() @@ -1101,8 +1099,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): super().setup_conds() if self.enable_hr: - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data) def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts()