From 7651eeea8d294ef3da2121e016d413e7c654d96b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 14 Feb 2024 18:17:46 -0500 Subject: [PATCH] Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication. --- .../diffusion/shared_invokeai_diffusion.py | 159 +++++++----------- 1 file changed, 59 insertions(+), 100 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 353256006a..2b72c808e4 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -232,28 +232,16 @@ class InvokeAIDiffuserComponent: down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter ): cross_attention_control_types_to_do = [] - context: Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: percent_through = step_index / total_step_count - cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step( - percent_through + cross_attention_control_types_to_do = ( + self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through) ) wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0 - if wants_cross_attention_control: - ( - unconditioned_next_x, - conditioned_next_x, - ) = self._apply_cross_attention_controlled_conditioning( - x=sample, - sigma=timestep, - conditioning_data=conditioning_data, - cross_attention_control_types_to_do=cross_attention_control_types_to_do, - down_block_additional_residuals=down_block_additional_residuals, - mid_block_additional_residual=mid_block_additional_residual, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, - ) - elif self.sequential_guidance: + if wants_cross_attention_control or self.sequential_guidance: + # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention + # control is currently only supported in sequential mode. ( unconditioned_next_x, conditioned_next_x, @@ -261,6 +249,7 @@ class InvokeAIDiffuserComponent: x=sample, sigma=timestep, conditioning_data=conditioning_data, + cross_attention_control_types_to_do=cross_attention_control_types_to_do, down_block_additional_residuals=down_block_additional_residuals, mid_block_additional_residual=mid_block_additional_residual, down_intrablock_additional_residuals=down_intrablock_additional_residuals, @@ -342,7 +331,15 @@ class InvokeAIDiffuserComponent: # methods below are called from do_diffusion_step and should be considered private to this class. - def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData): + def _apply_standard_conditioning( + self, + x, + sigma, + conditioning_data: ConditioningData, + down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet + mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet + down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter + ): """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at the cost of higher memory usage. 
""" @@ -390,6 +387,9 @@ class InvokeAIDiffuserComponent: both_conditionings, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + down_intrablock_additional_residuals=down_intrablock_additional_residuals, added_cond_kwargs=added_cond_kwargs, ) unconditioned_next_x, conditioned_next_x = both_results.chunk(2) @@ -400,6 +400,7 @@ class InvokeAIDiffuserComponent: x: torch.Tensor, sigma, conditioning_data: ConditioningData, + cross_attention_control_types_to_do: list[CrossAttentionType], down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter @@ -407,7 +408,8 @@ class InvokeAIDiffuserComponent: """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of slower execution speed. """ - # low-memory sequential path + # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet + # and T2I-Adapter residuals into two chunks. uncond_down_block, cond_down_block = None, None if down_block_additional_residuals is not None: uncond_down_block, cond_down_block = [], [] @@ -428,8 +430,26 @@ class InvokeAIDiffuserComponent: if mid_block_additional_residual is not None: uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) - # Run unconditional UNet denoising. + # If cross-attention control is enabled, prepare the SwapCrossAttnContext. + cross_attn_processor_context = None + if self.cross_attention_control_context is not None: + # Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do. + # This list is empty because cross-attention control is not applied in the unconditioned pass. This field + # will be populated before the conditioned pass. + cross_attn_processor_context = SwapCrossAttnContext( + modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning, + index_map=self.cross_attention_control_context.cross_attention_index_map, + mask=self.cross_attention_control_context.cross_attention_mask, + cross_attention_types_to_do=[], + ) + + ##################### + # Unconditioned pass + ##################### + cross_attention_kwargs = None + + # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. if conditioning_data.ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). cross_attention_kwargs = { @@ -439,6 +459,11 @@ class InvokeAIDiffuserComponent: ] } + # Prepare cross-attention control kwargs for the unconditioned pass. + if cross_attn_processor_context is not None: + cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} + + # Prepare SDXL conditioning kwargs for the unconditioned pass. added_cond_kwargs = None is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo if is_sdxl: @@ -447,6 +472,7 @@ class InvokeAIDiffuserComponent: "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, } + # Run unconditioned UNet denoising (i.e. negative prompt). 
unconditioned_next_x = self.model_forward_callback( x, sigma, @@ -458,8 +484,13 @@ class InvokeAIDiffuserComponent: added_cond_kwargs=added_cond_kwargs, ) - # Run conditional UNet denoising. + ################### + # Conditioned pass + ################### + cross_attention_kwargs = None + + # Prepare IP-Adapter cross-attention kwargs for the conditioned pass. if conditioning_data.ip_adapter_conditioning is not None: # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). cross_attention_kwargs = { @@ -469,6 +500,12 @@ class InvokeAIDiffuserComponent: ] } + # Prepare cross-attention control kwargs for the conditioned pass. + if cross_attn_processor_context is not None: + cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do + cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} + + # Prepare SDXL conditioning kwargs for the conditioned pass. added_cond_kwargs = None if is_sdxl: added_cond_kwargs = { @@ -476,6 +513,7 @@ class InvokeAIDiffuserComponent: "time_ids": conditioning_data.text_embeddings.add_time_ids, } + # Run conditioned UNet denoising (i.e. positive prompt). conditioned_next_x = self.model_forward_callback( x, sigma, @@ -488,85 +526,6 @@ class InvokeAIDiffuserComponent: ) return unconditioned_next_x, conditioned_next_x - def _apply_cross_attention_controlled_conditioning( - self, - x: torch.Tensor, - sigma, - conditioning_data, - cross_attention_control_types_to_do, - down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet - mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet - down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter - ): - context: Context = self.cross_attention_control_context - - uncond_down_block, cond_down_block = None, None - if down_block_additional_residuals is not None: - uncond_down_block, cond_down_block = [], [] - for down_block in down_block_additional_residuals: - _uncond_down, _cond_down = down_block.chunk(2) - uncond_down_block.append(_uncond_down) - cond_down_block.append(_cond_down) - - uncond_down_intrablock, cond_down_intrablock = None, None - if down_intrablock_additional_residuals is not None: - uncond_down_intrablock, cond_down_intrablock = [], [] - for down_intrablock in down_intrablock_additional_residuals: - _uncond_down, _cond_down = down_intrablock.chunk(2) - uncond_down_intrablock.append(_uncond_down) - cond_down_intrablock.append(_cond_down) - - uncond_mid_block, cond_mid_block = None, None - if mid_block_additional_residual is not None: - uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) - - cross_attn_processor_context = SwapCrossAttnContext( - modified_text_embeddings=context.arguments.edited_conditioning, - index_map=context.cross_attention_index_map, - mask=context.cross_attention_mask, - cross_attention_types_to_do=[], - ) - - added_cond_kwargs = None - is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo - if is_sdxl: - added_cond_kwargs = { - "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, - "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, - } - - # no cross attention for unconditioning (negative prompt) - unconditioned_next_x = self.model_forward_callback( - x, - sigma, - conditioning_data.unconditioned_embeddings.embeds, - {"swap_cross_attn_context": cross_attn_processor_context}, - down_block_additional_residuals=uncond_down_block, - 
mid_block_additional_residual=uncond_mid_block, - down_intrablock_additional_residuals=uncond_down_intrablock, - added_cond_kwargs=added_cond_kwargs, - ) - - if is_sdxl: - added_cond_kwargs = { - "text_embeds": conditioning_data.text_embeddings.pooled_embeds, - "time_ids": conditioning_data.text_embeddings.add_time_ids, - } - - # do requested cross attention types for conditioning (positive prompt) - cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do - conditioned_next_x = self.model_forward_callback( - x, - sigma, - conditioning_data.text_embeddings.embeds, - {"swap_cross_attn_context": cross_attn_processor_context}, - down_block_additional_residuals=cond_down_block, - mid_block_additional_residual=cond_mid_block, - down_intrablock_additional_residuals=cond_down_intrablock, - added_cond_kwargs=added_cond_kwargs, - ) - return unconditioned_next_x, conditioned_next_x - def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale): # to scale how much effect conditioning has, calculate the changes it does and then scale that scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
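
---

For reference, a minimal, self-contained sketch of the ideas this patch consolidates. This is not the InvokeAI implementation; the function names (split_batched_residuals, choose_conditioning_path, combine) are illustrative only. It shows (1) why ControlNet / T2I-Adapter residuals computed for a batched uncond+cond pass must be split with chunk(2) before the sequential passes, (2) how cross-attention control now forces the sequential path rather than a separate code path, and (3) the classifier-free-guidance combination mirrored from _combine.

import torch


def split_batched_residuals(residuals):
    """Split ControlNet / T2I-Adapter residuals computed for a batched
    (unconditioned + conditioned) forward pass into per-pass halves,
    as the sequential path does with .chunk(2)."""
    if residuals is None:
        return None, None
    uncond, cond = [], []
    for r in residuals:
        u, c = r.chunk(2)
        uncond.append(u)
        cond.append(c)
    return uncond, cond


def choose_conditioning_path(wants_cross_attention_control: bool, sequential_guidance: bool) -> str:
    """Cross-attention control is only supported sequentially, so it forces
    the sequential path; otherwise the sequential_guidance setting decides."""
    return "sequential" if (wants_cross_attention_control or sequential_guidance) else "batched"


def combine(unconditioned_next_x: torch.Tensor, conditioned_next_x: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Classifier-free guidance combination, mirroring _combine in the patched file."""
    scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
    return unconditioned_next_x + scaled_delta


if __name__ == "__main__":
    # Two fake down-block residuals, each batched as [uncond, cond] along dim 0.
    residuals = [torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)]
    uncond_res, cond_res = split_batched_residuals(residuals)
    print(choose_conditioning_path(wants_cross_attention_control=True, sequential_guidance=False))  # sequential
    print(uncond_res[0].shape, cond_res[0].shape)  # torch.Size([1, 4, 8, 8]) each

    uncond_x, cond_x = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    print(combine(uncond_x, cond_x, guidance_scale=7.5).shape)  # torch.Size([1, 4, 8, 8])

Under these assumptions, the sketch reflects the dispatch introduced in the first hunk: whenever cross-attention control is active for the current step, the sequential method is used even if sequential_guidance is off, which is what allowed _apply_cross_attention_controlled_conditioning to be deleted.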