support SD2.X models

This commit is contained in:
space-nuko 2023-02-11 06:18:34 -08:00
parent fb274229b2
commit 716a69237c

View File

@ -80,10 +80,13 @@ class UniPCSampler(object):
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
# SD 1.X is "noise", SD 2.X is "v"
model_type = "v" if self.model.parameterization == "v" else "noise"
model_fn = model_wrapper( model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c), lambda x, t, c: self.model.apply_model(x, t, c),
ns, ns,
model_type="noise", model_type=model_type,
guidance_type="classifier-free", guidance_type="classifier-free",
#condition=conditioning, #condition=conditioning,
#unconditional_condition=unconditional_conditioning, #unconditional_condition=unconditional_conditioning,