diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 6ea58d61..a206ea59 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -24,12 +24,22 @@ class ImageSaveParams: """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" +class CGFDenoiserParams: + def __init__(self, x_in, image_cond_in, sigma_in, sampling_step, total_sampling_steps): + self.x_in = x_in + self.image_cond_in = image_cond_in + self.sigma_in = sigma_in + self.sampling_step = sampling_step + self.total_sampling_steps = total_sampling_steps + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] callbacks_before_image_saved = [] callbacks_image_saved = [] +callbacks_cfg_denoiser = [] def clear_callbacks(): @@ -84,6 +94,14 @@ def image_saved_callback(params: ImageSaveParams): report_exception(c, 'image_saved_callback') +def cfg_denoiser_callback(params: CGFDenoiserParams): + for c in callbacks_cfg_denoiser: + try: + c.callback(params) + except Exception: + report_exception(c, 'cfg_denoiser_callback') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -130,3 +148,12 @@ def on_image_saved(callback): - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ add_callback(callbacks_image_saved, callback) + + +def on_cfg_denoiser(callback): + """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. + The callback is called with one argument: + - params: CGFDenoiserParams - parameters to be passed to the inner model and sampling state details. + """ + add_callback(callbacks_cfg_denoiser, callback) +