some extra lines I forgot to add for previous commit
This commit is contained in:
parent
1d11e89698
commit
737b73a820
@ -156,11 +156,10 @@ class CFGDenoiser(torch.nn.Module):
|
|||||||
if not skip_uncond:
|
if not skip_uncond:
|
||||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
|
||||||
|
|
||||||
|
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||||
if skip_uncond:
|
if skip_uncond:
|
||||||
#x_out = torch.cat([x_out, x_out[0:batch_size]]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
|
||||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
|
||||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||||
x_out = torch.cat([x_out, fake_uncond])
|
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||||
|
|
||||||
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
|
||||||
cfg_denoised_callback(denoised_params)
|
cfg_denoised_callback(denoised_params)
|
||||||
|
Loading…
Reference in New Issue
Block a user