import torch from typing import Union, List, Optional, Dict, Any, Tuple from diffusers.models.unet_2d_condition import UNet2DConditionOutput def unet_forward_XTI(self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" Args: sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=self.dtype) emb = self.time_embedding(t_emb) if self.config.num_class_embeds is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb # 2. pre-process sample = self.conv_in(sample) # 3. down down_block_res_samples = (sample,) down_i = 0 for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states[down_i:down_i+2], ) down_i += 2 else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states[6]) # 5. up up_i = 7 for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states[up_i:up_i+3], upsample_size=upsample_size, ) up_i += 3 else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) def downblock_forward_XTI( self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None ): output_states = () i = 0 for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample output_states += (hidden_states,) i += 1 if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states def upblock_forward_XTI( self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None, upsample_size=None, ): i = 0 for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states[i] )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states[i]).sample i += 1 if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states