Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication.

Author: Ryan Dick
Date:   2024-02-14 18:17:46 -05:00
parent d87ff3a206
commit 8721926f14
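
In short: cross-attention control ("cac") no longer has its own code path. Because it is only supported in sequential mode, requesting it now simply forces the sequential path, and _apply_standard_conditioning_sequentially takes over the work of the deleted _apply_cross_attention_controlled_conditioning method (it gains a cross_attention_control_types_to_do parameter, while _apply_standard_conditioning gains the ControlNet/T2I-Adapter residual parameters). The sketch below condenses the resulting dispatch logic from the hunks that follow. It is a readability paraphrase, not code from the commit: it is written as a free function over a hypothetical diffuser stand-in rather than the real method, and the batched call at the end assumes the residual kwargs are forwarded there as well, which these hunks do not show.

from typing import Optional

import torch


def choose_conditioning_path(
    diffuser,  # hypothetical stand-in for an InvokeAIDiffuserComponent instance ("self" in the real method)
    sample: torch.Tensor,
    timestep,
    conditioning_data,
    step_index: int,
    total_step_count: int,
    down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
    mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
    down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
):
    """Condensed paraphrase of the dispatch logic after this commit (see the first two hunks below)."""
    # Work out which cross-attention control types are active at this step, if any.
    cross_attention_control_types_to_do = []
    if diffuser.cross_attention_control_context is not None:
        percent_through = step_index / total_step_count
        cross_attention_control_types_to_do = (
            diffuser.cross_attention_control_context.get_active_cross_attention_control_types_for_step(
                percent_through
            )
        )
    wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0

    if wants_cross_attention_control or diffuser.sequential_guidance:
        # Cross-attention control is only supported in the sequential (low-memory) path, so it forces
        # that path. The old _apply_cross_attention_controlled_conditioning() is deleted and its
        # behavior is folded into _apply_standard_conditioning_sequentially().
        return diffuser._apply_standard_conditioning_sequentially(
            x=sample,
            sigma=timestep,
            conditioning_data=conditioning_data,
            cross_attention_control_types_to_do=cross_attention_control_types_to_do,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
        )

    # Batched path (one UNet call, faster, more memory). _apply_standard_conditioning() now also
    # accepts the residual kwargs; forwarding them here is an assumption, since these hunks do not
    # show that call site.
    return diffuser._apply_standard_conditioning(
        x=sample,
        sigma=timestep,
        conditioning_data=conditioning_data,
        down_block_additional_residuals=down_block_additional_residuals,
        mid_block_additional_residual=mid_block_additional_residual,
        down_intrablock_additional_residuals=down_intrablock_additional_residuals,
    )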


@@ -232,28 +232,16 @@ class InvokeAIDiffuserComponent:
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
         cross_attention_control_types_to_do = []
-        context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
             percent_through = step_index / total_step_count
-            cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
-                percent_through
+            cross_attention_control_types_to_do = (
+                self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
 
-        if wants_cross_attention_control:
-            (
-                unconditioned_next_x,
-                conditioned_next_x,
-            ) = self._apply_cross_attention_controlled_conditioning(
-                x=sample,
-                sigma=timestep,
-                conditioning_data=conditioning_data,
-                cross_attention_control_types_to_do=cross_attention_control_types_to_do,
-                down_block_additional_residuals=down_block_additional_residuals,
-                mid_block_additional_residual=mid_block_additional_residual,
-                down_intrablock_additional_residuals=down_intrablock_additional_residuals,
-            )
-        elif self.sequential_guidance:
+        if wants_cross_attention_control or self.sequential_guidance:
+            # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
+            # control is currently only supported in sequential mode.
             (
                 unconditioned_next_x,
                 conditioned_next_x,
@@ -261,6 +249,7 @@ class InvokeAIDiffuserComponent:
                 x=sample,
                 sigma=timestep,
                 conditioning_data=conditioning_data,
+                cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@@ -342,7 +331,15 @@ class InvokeAIDiffuserComponent:
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
-    def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData):
+    def _apply_standard_conditioning(
+        self,
+        x,
+        sigma,
+        conditioning_data: ConditioningData,
+        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
+        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
+        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
+    ):
         """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
         the cost of higher memory usage.
         """
@@ -390,6 +387,9 @@ class InvokeAIDiffuserComponent:
             both_conditionings,
             cross_attention_kwargs=cross_attention_kwargs,
             encoder_attention_mask=encoder_attention_mask,
+            down_block_additional_residuals=down_block_additional_residuals,
+            mid_block_additional_residual=mid_block_additional_residual,
+            down_intrablock_additional_residuals=down_intrablock_additional_residuals,
             added_cond_kwargs=added_cond_kwargs,
         )
         unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
@@ -400,6 +400,7 @@ class InvokeAIDiffuserComponent:
         x: torch.Tensor,
         sigma,
         conditioning_data: ConditioningData,
+        cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -407,7 +408,8 @@ class InvokeAIDiffuserComponent:
         """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
         slower execution speed.
         """
-        # low-memory sequential path
+        # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
+        # and T2I-Adapter residuals into two chunks.
         uncond_down_block, cond_down_block = None, None
         if down_block_additional_residuals is not None:
             uncond_down_block, cond_down_block = [], []
@@ -428,8 +430,26 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
-        # Run unconditional UNet denoising.
+        # If cross-attention control is enabled, prepare the SwapCrossAttnContext.
+        cross_attn_processor_context = None
+        if self.cross_attention_control_context is not None:
+            # Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
+            # This list is empty because cross-attention control is not applied in the unconditioned pass. This field
+            # will be populated before the conditioned pass.
+            cross_attn_processor_context = SwapCrossAttnContext(
+                modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
+                index_map=self.cross_attention_control_context.cross_attention_index_map,
+                mask=self.cross_attention_control_context.cross_attention_mask,
+                cross_attention_types_to_do=[],
+            )
+
+        #####################
+        # Unconditioned pass
+        #####################
+
         cross_attention_kwargs = None
+
+        # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
@@ -439,6 +459,11 @@ class InvokeAIDiffuserComponent:
                 ]
             }
 
+        # Prepare cross-attention control kwargs for the unconditioned pass.
+        if cross_attn_processor_context is not None:
+            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+
+        # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
         is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
@@ -447,6 +472,7 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
             }
 
+        # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
@@ -458,8 +484,13 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs=added_cond_kwargs,
         )
 
-        # Run conditional UNet denoising.
+        ###################
+        # Conditioned pass
+        ###################
+
         cross_attention_kwargs = None
+
+        # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
@@ -469,6 +500,12 @@ class InvokeAIDiffuserComponent:
                 ]
             }
 
+        # Prepare cross-attention control kwargs for the conditioned pass.
+        if cross_attn_processor_context is not None:
+            cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
+            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+
+        # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
         if is_sdxl:
             added_cond_kwargs = {
@@ -476,6 +513,7 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.text_embeddings.add_time_ids,
             }
 
+        # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
@@ -488,85 +526,6 @@ class InvokeAIDiffuserComponent:
         )
         return unconditioned_next_x, conditioned_next_x
 
-    def _apply_cross_attention_controlled_conditioning(
-        self,
-        x: torch.Tensor,
-        sigma,
-        conditioning_data,
-        cross_attention_control_types_to_do,
-        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
-        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
-        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
-    ):
-        context: Context = self.cross_attention_control_context
-
-        uncond_down_block, cond_down_block = None, None
-        if down_block_additional_residuals is not None:
-            uncond_down_block, cond_down_block = [], []
-            for down_block in down_block_additional_residuals:
-                _uncond_down, _cond_down = down_block.chunk(2)
-                uncond_down_block.append(_uncond_down)
-                cond_down_block.append(_cond_down)
-
-        uncond_down_intrablock, cond_down_intrablock = None, None
-        if down_intrablock_additional_residuals is not None:
-            uncond_down_intrablock, cond_down_intrablock = [], []
-            for down_intrablock in down_intrablock_additional_residuals:
-                _uncond_down, _cond_down = down_intrablock.chunk(2)
-                uncond_down_intrablock.append(_uncond_down)
-                cond_down_intrablock.append(_cond_down)
-
-        uncond_mid_block, cond_mid_block = None, None
-        if mid_block_additional_residual is not None:
-            uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
-
-        cross_attn_processor_context = SwapCrossAttnContext(
-            modified_text_embeddings=context.arguments.edited_conditioning,
-            index_map=context.cross_attention_index_map,
-            mask=context.cross_attention_mask,
-            cross_attention_types_to_do=[],
-        )
-
-        added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
-            added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
-            }
-
-        # no cross attention for unconditioning (negative prompt)
-        unconditioned_next_x = self.model_forward_callback(
-            x,
-            sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
-            {"swap_cross_attn_context": cross_attn_processor_context},
-            down_block_additional_residuals=uncond_down_block,
-            mid_block_additional_residual=uncond_mid_block,
-            down_intrablock_additional_residuals=uncond_down_intrablock,
-            added_cond_kwargs=added_cond_kwargs,
-        )
-
-        if is_sdxl:
-            added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
-            }
-
-        # do requested cross attention types for conditioning (positive prompt)
-        cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-        conditioned_next_x = self.model_forward_callback(
-            x,
-            sigma,
-            conditioning_data.text_embeddings.embeds,
-            {"swap_cross_attn_context": cross_attn_processor_context},
-            down_block_additional_residuals=cond_down_block,
-            mid_block_additional_residual=cond_mid_block,
-            down_intrablock_additional_residuals=cond_down_intrablock,
-            added_cond_kwargs=added_cond_kwargs,
-        )
-
-        return unconditioned_next_x, conditioned_next_x
-
     def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
         # to scale how much effect conditioning has, calculate the changes it does and then scale that
         scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale