From 125319988984987801dc4b4ab1e5ed36e9b211c5 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 03:30:20 -0800 Subject: [PATCH 1/8] Working UniPC (for batch size 1) --- javascript/hints.js | 1 + modules/models/diffusion/uni_pc/__init__.py | 1 + modules/models/diffusion/uni_pc/sampler.py | 85 ++ modules/models/diffusion/uni_pc/uni_pc.py | 858 ++++++++++++++++++++ modules/processing.py | 2 +- modules/sd_samplers_compvis.py | 35 +- test/basic_features/txt2img_test.py | 2 + 7 files changed, 978 insertions(+), 6 deletions(-) create mode 100644 modules/models/diffusion/uni_pc/__init__.py create mode 100644 modules/models/diffusion/uni_pc/sampler.py create mode 100644 modules/models/diffusion/uni_pc/uni_pc.py diff --git a/javascript/hints.js b/javascript/hints.js index 9aa82f24..0a0620e3 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -6,6 +6,7 @@ titles = { "GFPGAN": "Restore low quality faces using GFPGAN neural network", "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help", "DDIM": "Denoising Diffusion Implicit Models - best at inpainting", + "UniPC": "Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models", "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution", "Batch count": "How many batches of images to create", diff --git a/modules/models/diffusion/uni_pc/__init__.py b/modules/models/diffusion/uni_pc/__init__.py new file mode 100644 index 00000000..e1265e3f --- /dev/null +++ b/modules/models/diffusion/uni_pc/__init__.py @@ -0,0 +1 @@ +from .sampler import UniPCSampler diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py new file mode 100644 index 00000000..7cccd8a2 --- /dev/null +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -0,0 +1,85 @@ +"""SAMPLING ONLY.""" + +import torch + +from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC + +class UniPCSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.before_sample = None + self.after_sample = None + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def set_hooks(self, before, after): + self.before_sample = before + self.after_sample = after + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type="noise", + guidance_type="classifier-free", + #condition=conditioning, + #unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) + x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) + + return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py new file mode 100644 index 00000000..ec6b37da --- /dev/null +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -0,0 +1,858 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + #condition=None, + #unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, None, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input, condition): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous, condition, unconditional_condition): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input, condition) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([ + unconditional_condition[k][i], + condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([ + unconditional_condition[k], + condition[k]]) + elif isinstance(condition, list): + c_in = list() + assert isinstance(unconditional_condition, list) + for i in range(len(condition)): + c_in.append(torch.cat([unconditional_condition[i], condition[i]])) + else: + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class UniPC: + def __init__( + self, + model_fn, + noise_schedule, + predict_x0=True, + thresholding=False, + max_val=1., + variant='bh1', + condition=None, + unconditional_condition=None, + before_sample=None, + after_sample=None + ): + """Construct a UniPC. + + We support both data_prediction and noise_prediction. + """ + self.model_fn_ = model_fn + self.noise_schedule = noise_schedule + self.variant = variant + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + self.condition = condition + self.unconditional_condition = unconditional_condition + self.before_sample = before_sample + self.after_sample = after_sample + + def dynamic_thresholding_fn(self, x0, t=None): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model(self, x, t): + cond = self.condition + uncond = self.unconditional_condition + if self.before_sample is not None: + x, t, cond, uncond = self.before_sample(x, t, cond, uncond) + res = self.model_fn_(x, t, cond, uncond) + if self.after_sample is not None: + x, t, cond, uncond, res = self.after_sample(x, t, cond, uncond, res) + + if isinstance(res, tuple): + # (None, pred_x0) + res = res[1] + + return res + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + from pprint import pp + print("X:") + pp(x) + print("sigma_t:") + pp(sigma_t) + print("noise:") + pp(noise) + print("alpha_t:") + pp(alpha_t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = steps + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs): + if len(t.shape) == 0: + t = t.view(-1) + if 'bh' in self.variant: + return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + else: + assert self.variant == 'vary_coeff' + return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) + + def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): + print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + ns = self.noise_schedule + assert order <= len(model_prev_list) + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_t = ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = (lambda_prev_i - lambda_prev_0) / h + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + K = len(rks) + # build C matrix + C = [] + + col = torch.ones_like(rks) + for k in range(1, K + 1): + C.append(col) + col = col * rks / (k + 1) + C = torch.stack(C, dim=1) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + C_inv_p = torch.linalg.inv(C[:-1, :-1]) + A_p = C_inv_p + + if use_corrector: + print('using corrector') + C_inv = torch.linalg.inv(C) + A_c = C_inv + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) + h_phi_ks = [] + factorial_k = 1 + h_phi_k = h_phi_1 + for k in range(1, K + 2): + h_phi_ks.append(h_phi_k) + h_phi_k = h_phi_k / hh - 1 / factorial_k + factorial_k *= (k + 1) + + model_t = None + if self.predict_x0: + x_t_ = ( + sigma_t / sigma_prev_0 * x + - alpha_t * h_phi_1 * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + else: + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + x_t_ = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * h_phi_1) * model_prev_0 + ) + # now predictor + x_t = x_t_ + if len(D1s) > 0: + # compute the residuals for predictor + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k]) + # now corrector + if use_corrector: + model_t = self.model_fn(x_t, t) + D1_t = (model_t - model_prev_0) + x_t = x_t_ + k = 0 + for k in range(K - 1): + x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1]) + x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1]) + return x_t, model_t + + def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): + print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + ns = self.noise_schedule + assert order <= len(model_prev_list) + dims = x.dim() + + # first compute rks + t_prev_0 = t_prev_list[-1] + lambda_prev_0 = ns.marginal_lambda(t_prev_0) + lambda_t = ns.marginal_lambda(t) + model_prev_0 = model_prev_list[-1] + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + alpha_t = torch.exp(log_alpha_t) + + h = lambda_t - lambda_prev_0 + + rks = [] + D1s = [] + for i in range(1, order): + t_prev_i = t_prev_list[-(i + 1)] + model_prev_i = model_prev_list[-(i + 1)] + lambda_prev_i = ns.marginal_lambda(t_prev_i) + rk = ((lambda_prev_i - lambda_prev_0) / h)[0] + rks.append(rk) + D1s.append((model_prev_i - model_prev_0) / rk) + + rks.append(1.) + rks = torch.tensor(rks, device=x.device) + + R = [] + b = [] + + hh = -h[0] if self.predict_x0 else h[0] + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.variant == 'bh1': + B_h = hh + elif self.variant == 'bh2': + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= (i + 1) + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=x.device) + + # now predictor + use_predictor = len(D1s) > 0 and x_t is None + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + if x_t is None: + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], device=b.device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]) + else: + D1s = None + + if use_corrector: + print('using corrector') + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], device=b.device) + else: + rhos_c = torch.linalg.solve(R, b) + + model_t = None + if self.predict_x0: + x_t_ = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0 + ) + + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x + - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0 + ) + if x_t is None: + if use_predictor: + pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res + + if use_corrector: + model_t = self.model_fn(x_t, t) + if D1s is not None: + corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = (model_t - model_prev_0) + x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + return x_t, model_t + + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, corrector=False, + ): + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list.append(model_x) + t_prev_list.append(vec_t) + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final: + step_order = min(order, steps + 1 - step) + else: + step_order = order + print('this step order:', step_order) + if step == steps: + print('do not run corrector at the last step') + use_corrector = False + else: + use_corrector = True + x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + if model_x is None: + model_x = self.model_fn(x, vec_t) + model_prev_list[-1] = model_x + else: + raise NotImplementedError() + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,)*(dims - 1)] diff --git a/modules/processing.py b/modules/processing.py index e1b53ac0..11e726df 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index d03131cd..86fa1c5b 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -7,19 +7,27 @@ import torch from modules.shared import state from modules import sd_samplers_common, prompt_parser, shared +import modules.models.diffusion.uni_pc samplers_data_compvis = [ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), + sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {}), ] class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) + self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim + self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) + self.orig_p_sample_ddim = None + if self.is_plms: + self.orig_p_sample_ddim = self.sampler.p_sample_plms + elif self.is_ddim: + self.orig_p_sample_ddim = self.sampler.p_sample_ddim self.mask = None self.nmask = None self.init_latent = None @@ -45,6 +53,15 @@ class VanillaStableDiffusionSampler: return self.last_latent def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): + x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) + + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + + x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) + + return res + + def before_sample(self, x, ts, cond, unconditional_conditioning): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -76,7 +93,7 @@ class VanillaStableDiffusionSampler: if self.mask is not None: img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x_dec = img_orig * self.mask + self.nmask * x_dec + x = img_orig * self.mask + self.nmask * x # Wrap the image conditioning back up since the DDIM code can accept the dict directly. # Note that they need to be lists because it just concatenates them later. @@ -84,7 +101,13 @@ class VanillaStableDiffusionSampler: cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) + return x, ts, cond, unconditional_conditioning + + def after_sample(self, x, ts, cond, uncond, res): + if self.is_unipc: + # unipc model_fn returns (pred_x0) + # p_sample_ddim returns (x_prev, pred_x0) + res = (None, res[0]) if self.mask is not None: self.last_latent = self.init_latent * self.mask + self.nmask * res[1] @@ -97,7 +120,7 @@ class VanillaStableDiffusionSampler: state.sampling_step = self.step shared.total_tqdm.update() - return res + return x, ts, cond, uncond, res def initialize(self, p): self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim @@ -107,12 +130,14 @@ class VanillaStableDiffusionSampler: for fieldname in ['p_sample_ddim', 'p_sample_plms']: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) + if self.is_unipc: + self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r)) self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 5aa43a44..cb525fbb 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -66,6 +66,8 @@ class TestTxt2ImgWorking(unittest.TestCase): self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) self.simple_txt2img["sampler_index"] = "DDIM" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + self.simple_txt2img["sampler_index"] = "UniPC" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) def test_txt2img_multiple_batches_performed(self): self.simple_txt2img["n_iter"] = 2 From 21880eb9e57b884635a07d2360831b4186afddf4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 04:47:08 -0800 Subject: [PATCH 2/8] Fix logspam and live previews --- modules/models/diffusion/uni_pc/sampler.py | 20 ++++++++++---- modules/models/diffusion/uni_pc/uni_pc.py | 32 ++++++++++------------ modules/sd_samplers_compvis.py | 20 ++++++++------ 3 files changed, 41 insertions(+), 31 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 7cccd8a2..219e9862 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -19,9 +19,10 @@ class UniPCSampler(object): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def set_hooks(self, before, after): - self.before_sample = before - self.after_sample = after + def set_hooks(self, before_sample, after_sample, after_update): + self.before_sample = before_sample + self.after_sample = after_sample + self.after_update = after_update @torch.no_grad() def sample(self, @@ -50,9 +51,17 @@ class UniPCSampler(object): ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") @@ -60,6 +69,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) + print(f'Data shape for UniPC sampling is {size}, eta {eta}') device = self.model.betas.device if x_T is None: @@ -79,7 +89,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index ec6b37da..31ee81a6 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -378,7 +378,8 @@ class UniPC: condition=None, unconditional_condition=None, before_sample=None, - after_sample=None + after_sample=None, + after_update=None ): """Construct a UniPC. @@ -394,6 +395,7 @@ class UniPC: self.unconditional_condition = unconditional_condition self.before_sample = before_sample self.after_sample = after_sample + self.after_update = after_update def dynamic_thresholding_fn(self, x0, t=None): """ @@ -434,15 +436,6 @@ class UniPC: noise = self.noise_prediction_fn(x, t) dims = x.dim() alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - from pprint import pp - print("X:") - pp(x) - print("sigma_t:") - pp(sigma_t) - print("noise:") - pp(noise) - print("alpha_t:") - pp(alpha_t) x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) if self.thresholding: p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. @@ -524,7 +517,7 @@ class UniPC: return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') ns = self.noise_schedule assert order <= len(model_prev_list) @@ -568,7 +561,7 @@ class UniPC: A_p = C_inv_p if use_corrector: - print('using corrector') + #print('using corrector') C_inv = torch.linalg.inv(C) A_c = C_inv @@ -627,7 +620,7 @@ class UniPC: return x_t, model_t def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: B(h))') + #print(f'using unified predictor-corrector with order {order} (solver type: B(h))') ns = self.noise_schedule assert order <= len(model_prev_list) dims = x.dim() @@ -695,7 +688,7 @@ class UniPC: D1s = None if use_corrector: - print('using corrector') + #print('using corrector') # for order 1, we use a simplified version if order == 1: rhos_c = torch.tensor([0.5], device=b.device) @@ -755,8 +748,9 @@ class UniPC: t_T = self.noise_schedule.T if t_start is None else t_start device = x.device if method == 'multistep': - assert steps >= order + assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) @@ -768,6 +762,8 @@ class UniPC: x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True) if model_x is None: model_x = self.model_fn(x, vec_t) + if self.after_update is not None: + self.after_update(x, model_x) model_prev_list.append(model_x) t_prev_list.append(vec_t) for step in range(order, steps + 1): @@ -776,13 +772,15 @@ class UniPC: step_order = min(order, steps + 1 - step) else: step_order = order - print('this step order:', step_order) + #print('this step order:', step_order) if step == steps: - print('do not run corrector at the last step') + #print('do not run corrector at the last step') use_corrector = False else: use_corrector = True x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector) + if self.after_update is not None: + self.after_update(x, model_x) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 86fa1c5b..946079ae 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler: return x, ts, cond, unconditional_conditioning - def after_sample(self, x, ts, cond, uncond, res): - if self.is_unipc: - # unipc model_fn returns (pred_x0) - # p_sample_ddim returns (x_prev, pred_x0) - res = (None, res[0]) - + def update_step(self, last_latent): if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * res[1] + self.last_latent = self.init_latent * self.mask + self.nmask * last_latent else: - self.last_latent = res[1] + self.last_latent = last_latent sd_samplers_common.store_latent(self.last_latent) @@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler: state.sampling_step = self.step shared.total_tqdm.update() + def after_sample(self, x, ts, cond, uncond, res): + if not self.is_unipc: + self.update_step(res[1]) + return x, ts, cond, uncond, res + def unipc_after_update(self, x, model_x): + self.update_step(x) + def initialize(self, p): self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim if self.eta != 0.0: @@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler: if hasattr(self.sampler, fieldname): setattr(self.sampler, fieldname, self.p_sample_ddim_hook) if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r)) + self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None From c88dcc20d495dab4be2692bdff30277112dbe416 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:00:09 -0800 Subject: [PATCH 3/8] UniPC does not support img2img (for now) --- modules/processing.py | 2 +- modules/sd_samplers.py | 2 +- modules/sd_samplers_compvis.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 11e726df..b7cf5357 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = 'DDIM' # PLMS/UniPC does not support img2img so we just silently switch ot DDIM self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 28c2136f..ff361f22 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -32,7 +32,7 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(shared.opts.hide_samplers) - hidden_img2img = set(shared.opts.hide_samplers + ['PLMS']) + hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) samplers = [x for x in all_samplers if x.name not in hidden] samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 946079ae..ad39ab2b 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -139,7 +139,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM' or self.config.name == "UniPC") and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 From 79ffb9453f8eddbdd4e316b9d9c75812b0eea4e1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:27:05 -0800 Subject: [PATCH 4/8] Add UniPC sampler settings --- modules/models/diffusion/uni_pc/sampler.py | 5 +++-- modules/models/diffusion/uni_pc/uni_pc.py | 2 +- modules/shared.py | 5 +++++ scripts/xyz_grid.py | 7 +++++++ 4 files changed, 16 insertions(+), 3 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 219e9862..e66a21e3 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -3,6 +3,7 @@ import torch from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC +from modules import shared class UniPCSampler(object): def __init__(self, model, **kwargs): @@ -89,7 +90,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) - x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=shared.opts.uni_pc_thresholding, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None diff --git a/modules/models/diffusion/uni_pc/uni_pc.py b/modules/models/diffusion/uni_pc/uni_pc.py index 31ee81a6..df63d1bc 100644 --- a/modules/models/diffusion/uni_pc/uni_pc.py +++ b/modules/models/diffusion/uni_pc/uni_pc.py @@ -750,7 +750,7 @@ class UniPC: if method == 'multistep': assert steps >= order, "UniPC order must be < sampling steps" timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps") + print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps, order {order}") assert timesteps.shape[0] - 1 == steps with torch.no_grad(): vec_t = timesteps[0].expand((x.shape[0])) diff --git a/modules/shared.py b/modules/shared.py index 79fbf724..34242073 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -480,6 +480,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), + 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), + 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 150 - 1, "step": 1}), + 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), + 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 5982cfba..72421e0c 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -126,6 +126,10 @@ def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): p.styles.extend(x.split(',')) +def apply_uni_pc_order(p, x, xs): + opts.data["uni_pc_order"] = min(x, p.steps - 1) + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -202,6 +206,7 @@ axis_options = [ AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), + AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5), ] @@ -310,9 +315,11 @@ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.vae = opts.sd_vae + self.uni_pc_order = opts.uni_pc_order def __exit__(self, exc_type, exc_value, tb): opts.data["sd_vae"] = self.vae + opts.data["uni_pc_order"] = self.uni_pc_order modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() From 06cb0dc92095647e4856be10b4d7dc12f5e11fa1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 05:36:41 -0800 Subject: [PATCH 5/8] Fix UniPC order --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 34242073..670d4954 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -482,7 +482,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), - 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 150 - 1, "step": 1}), + 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), })) From fb274229b2c5c1a89dac0b3da28c08c92d71fd95 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Feb 2023 14:30:35 -0800 Subject: [PATCH 6/8] bug fix --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/processing.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index e66a21e3..0bef6eed 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -70,7 +70,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for UniPC sampling is {size}, eta {eta}') + print(f'Data shape for UniPC sampling is {size}') device = self.model.betas.device if x_T is None: diff --git a/modules/processing.py b/modules/processing.py index b7cf5357..0ca15491 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -884,7 +884,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - img2img_sampler_name = 'DDIM' # PLMS/UniPC does not support img2img so we just silently switch ot DDIM + img2img_sampler_name = self.sampler_name + if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM + img2img_sampler_name = 'DDIM' self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] From 716a69237cefb385f71105dbbf50e92d664e0f42 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 11 Feb 2023 06:18:34 -0800 Subject: [PATCH 7/8] support SD2.X models --- modules/models/diffusion/uni_pc/sampler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 0bef6eed..708a9b2b 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -80,10 +80,13 @@ class UniPCSampler(object): ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + # SD 1.X is "noise", SD 2.X is "v" + model_type = "v" if self.model.parameterization == "v" else "noise" + model_fn = model_wrapper( lambda x, t, c: self.model.apply_model(x, t, c), ns, - model_type="noise", + model_type=model_type, guidance_type="classifier-free", #condition=conditioning, #unconditional_condition=unconditional_conditioning, From 5fef67f6ee949a61826a3a043ea8610fd89fc371 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 10 Mar 2023 19:56:14 -0500 Subject: [PATCH 8/8] Requested changes --- modules/models/diffusion/uni_pc/sampler.py | 2 +- modules/sd_samplers_compvis.py | 4 +++- modules/shared.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 708a9b2b..6bb3bb21 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -93,7 +93,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=shared.opts.uni_pc_thresholding, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) return x.to(device), None diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index ad39ab2b..7d07c4a5 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -140,10 +140,12 @@ class VanillaStableDiffusionSampler: def adjust_steps_if_invalid(self, p, num_steps): if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): + if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: + num_steps = shared.opts.uni_pc_order valid_step = 999 / (1000 // num_steps) if valid_step == math.floor(valid_step): return int(valid_step) + 1 - + return num_steps def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): diff --git a/modules/shared.py b/modules/shared.py index 7c559fa4..29f8dccb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -485,10 +485,9 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), - 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "vary_coeff"]}), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}), 'uni_pc_order': OptionInfo(3, "UniPC order (must be < sampling steps)", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}), - 'uni_pc_thresholding': OptionInfo(False, "UniPC thresholding"), 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), }))