mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication.
This commit is contained in:
parent
d87ff3a206
commit
8721926f14
@ -232,28 +232,16 @@ class InvokeAIDiffuserComponent:
|
|||||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
):
|
):
|
||||||
cross_attention_control_types_to_do = []
|
cross_attention_control_types_to_do = []
|
||||||
context: Context = self.cross_attention_control_context
|
|
||||||
if self.cross_attention_control_context is not None:
|
if self.cross_attention_control_context is not None:
|
||||||
percent_through = step_index / total_step_count
|
percent_through = step_index / total_step_count
|
||||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
|
cross_attention_control_types_to_do = (
|
||||||
percent_through
|
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||||
)
|
)
|
||||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||||
|
|
||||||
if wants_cross_attention_control:
|
if wants_cross_attention_control or self.sequential_guidance:
|
||||||
(
|
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
||||||
unconditioned_next_x,
|
# control is currently only supported in sequential mode.
|
||||||
conditioned_next_x,
|
|
||||||
) = self._apply_cross_attention_controlled_conditioning(
|
|
||||||
x=sample,
|
|
||||||
sigma=timestep,
|
|
||||||
conditioning_data=conditioning_data,
|
|
||||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
|
||||||
down_block_additional_residuals=down_block_additional_residuals,
|
|
||||||
mid_block_additional_residual=mid_block_additional_residual,
|
|
||||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
|
||||||
)
|
|
||||||
elif self.sequential_guidance:
|
|
||||||
(
|
(
|
||||||
unconditioned_next_x,
|
unconditioned_next_x,
|
||||||
conditioned_next_x,
|
conditioned_next_x,
|
||||||
@ -261,6 +249,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
x=sample,
|
x=sample,
|
||||||
sigma=timestep,
|
sigma=timestep,
|
||||||
conditioning_data=conditioning_data,
|
conditioning_data=conditioning_data,
|
||||||
|
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||||
down_block_additional_residuals=down_block_additional_residuals,
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
mid_block_additional_residual=mid_block_additional_residual,
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
@ -342,7 +331,15 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
|
||||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData):
|
def _apply_standard_conditioning(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
sigma,
|
||||||
|
conditioning_data: ConditioningData,
|
||||||
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
|
):
|
||||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||||
the cost of higher memory usage.
|
the cost of higher memory usage.
|
||||||
"""
|
"""
|
||||||
@ -390,6 +387,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
both_conditionings,
|
both_conditionings,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
down_block_additional_residuals=down_block_additional_residuals,
|
||||||
|
mid_block_additional_residual=mid_block_additional_residual,
|
||||||
|
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
)
|
)
|
||||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||||
@ -400,6 +400,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
|
cross_attention_control_types_to_do: list[CrossAttentionType],
|
||||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||||
@ -407,7 +408,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
||||||
slower execution speed.
|
slower execution speed.
|
||||||
"""
|
"""
|
||||||
# low-memory sequential path
|
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
|
||||||
|
# and T2I-Adapter residuals into two chunks.
|
||||||
uncond_down_block, cond_down_block = None, None
|
uncond_down_block, cond_down_block = None, None
|
||||||
if down_block_additional_residuals is not None:
|
if down_block_additional_residuals is not None:
|
||||||
uncond_down_block, cond_down_block = [], []
|
uncond_down_block, cond_down_block = [], []
|
||||||
@ -428,8 +430,26 @@ class InvokeAIDiffuserComponent:
|
|||||||
if mid_block_additional_residual is not None:
|
if mid_block_additional_residual is not None:
|
||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
# Run unconditional UNet denoising.
|
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
|
||||||
|
cross_attn_processor_context = None
|
||||||
|
if self.cross_attention_control_context is not None:
|
||||||
|
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
|
||||||
|
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
|
||||||
|
# will be populated before the conditioned pass.
|
||||||
|
cross_attn_processor_context = SwapCrossAttnContext(
|
||||||
|
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
|
||||||
|
index_map=self.cross_attention_control_context.cross_attention_index_map,
|
||||||
|
mask=self.cross_attention_control_context.cross_attention_mask,
|
||||||
|
cross_attention_types_to_do=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
#####################
|
||||||
|
# Unconditioned pass
|
||||||
|
#####################
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
|
|
||||||
|
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
cross_attention_kwargs = {
|
||||||
@ -439,6 +459,11 @@ class InvokeAIDiffuserComponent:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Prepare cross-attention control kwargs for the unconditioned pass.
|
||||||
|
if cross_attn_processor_context is not None:
|
||||||
|
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
||||||
|
|
||||||
|
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
@ -447,6 +472,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||||
unconditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
@ -458,8 +484,13 @@ class InvokeAIDiffuserComponent:
|
|||||||
added_cond_kwargs=added_cond_kwargs,
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run conditional UNet denoising.
|
###################
|
||||||
|
# Conditioned pass
|
||||||
|
###################
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
|
|
||||||
|
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
cross_attention_kwargs = {
|
||||||
@ -469,6 +500,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Prepare cross-attention control kwargs for the conditioned pass.
|
||||||
|
if cross_attn_processor_context is not None:
|
||||||
|
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||||
|
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
|
||||||
|
|
||||||
|
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
@ -476,6 +513,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
@ -488,85 +526,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
return unconditioned_next_x, conditioned_next_x
|
return unconditioned_next_x, conditioned_next_x
|
||||||
|
|
||||||
def _apply_cross_attention_controlled_conditioning(
|
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
sigma,
|
|
||||||
conditioning_data,
|
|
||||||
cross_attention_control_types_to_do,
|
|
||||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
|
||||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
|
||||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
|
||||||
):
|
|
||||||
context: Context = self.cross_attention_control_context
|
|
||||||
|
|
||||||
uncond_down_block, cond_down_block = None, None
|
|
||||||
if down_block_additional_residuals is not None:
|
|
||||||
uncond_down_block, cond_down_block = [], []
|
|
||||||
for down_block in down_block_additional_residuals:
|
|
||||||
_uncond_down, _cond_down = down_block.chunk(2)
|
|
||||||
uncond_down_block.append(_uncond_down)
|
|
||||||
cond_down_block.append(_cond_down)
|
|
||||||
|
|
||||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
|
||||||
if down_intrablock_additional_residuals is not None:
|
|
||||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
|
||||||
for down_intrablock in down_intrablock_additional_residuals:
|
|
||||||
_uncond_down, _cond_down = down_intrablock.chunk(2)
|
|
||||||
uncond_down_intrablock.append(_uncond_down)
|
|
||||||
cond_down_intrablock.append(_cond_down)
|
|
||||||
|
|
||||||
uncond_mid_block, cond_mid_block = None, None
|
|
||||||
if mid_block_additional_residual is not None:
|
|
||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
|
||||||
|
|
||||||
cross_attn_processor_context = SwapCrossAttnContext(
|
|
||||||
modified_text_embeddings=context.arguments.edited_conditioning,
|
|
||||||
index_map=context.cross_attention_index_map,
|
|
||||||
mask=context.cross_attention_mask,
|
|
||||||
cross_attention_types_to_do=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
added_cond_kwargs = None
|
|
||||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
|
||||||
if is_sdxl:
|
|
||||||
added_cond_kwargs = {
|
|
||||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
|
||||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
|
||||||
}
|
|
||||||
|
|
||||||
# no cross attention for unconditioning (negative prompt)
|
|
||||||
unconditioned_next_x = self.model_forward_callback(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
|
||||||
down_block_additional_residuals=uncond_down_block,
|
|
||||||
mid_block_additional_residual=uncond_mid_block,
|
|
||||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sdxl:
|
|
||||||
added_cond_kwargs = {
|
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
|
||||||
}
|
|
||||||
|
|
||||||
# do requested cross attention types for conditioning (positive prompt)
|
|
||||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
|
||||||
conditioned_next_x = self.model_forward_callback(
|
|
||||||
x,
|
|
||||||
sigma,
|
|
||||||
conditioning_data.text_embeddings.embeds,
|
|
||||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
|
||||||
down_block_additional_residuals=cond_down_block,
|
|
||||||
mid_block_additional_residual=cond_mid_block,
|
|
||||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
|
||||||
added_cond_kwargs=added_cond_kwargs,
|
|
||||||
)
|
|
||||||
return unconditioned_next_x, conditioned_next_x
|
|
||||||
|
|
||||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||||
|
Loading…
Reference in New Issue
Block a user