diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 39fdca70..2ac44f6c 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -37,6 +37,11 @@ samplers = [ ] samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] +sampler_extra_params = { + 'sample_euler':['s_churn','s_tmin','s_noise'], + 'sample_heun' :['s_churn','s_tmin','s_noise'], + 'sample_dpm_2':['s_churn','s_tmin','s_noise'], +} def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: @@ -224,6 +229,7 @@ class KDiffusionSampler: self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) + self.extra_params = sampler_extra_params.get(funcname,[]) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None self.sampler_noise_index = 0 @@ -269,7 +275,12 @@ class KDiffusionSampler: if self.sampler_noises is not None: k_diffusion.sampling.torch = TorchHijack(self) - return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state) + extra_params_kwargs = {} + for val in self.extra_params: + if hasattr(opts,val): + extra_params_kwargs[val] = getattr(opts,val) + + return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): steps = steps or p.steps @@ -286,7 +297,12 @@ class KDiffusionSampler: if self.sampler_noises is not None: k_diffusion.sampling.torch = TorchHijack(self) - samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state) + extra_params_kwargs = {} + for val in self.extra_params: + if hasattr(opts,val): + extra_params_kwargs[val] = getattr(opts,val) + + samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs) return samples