Merge pull request #8515 from EllangoK/unipc-typo
Fix dims typo in unipc
This commit is contained in:
commit
beb96bd115
@ -719,7 +719,7 @@ class UniPC:
|
|||||||
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
|
||||||
else:
|
else:
|
||||||
x_t_ = (
|
x_t_ = (
|
||||||
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
|
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
|
||||||
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
- expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
|
||||||
)
|
)
|
||||||
if x_t is None:
|
if x_t is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user