Update the diffusion logic to use the new regional prompting feature.

This commit is contained in:
Ryan Dick
2024-03-08 14:34:49 -05:00
committed by Kent Keirsey
parent 7ca677578e
commit a78df8123f
3 changed files with 96 additions and 46 deletions

View File

@ -88,13 +88,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# End unmodified block from AttnProcessor2_0.
# Handle regional prompt attention masks.
if regional_prompt_data is not None:
if regional_prompt_data is not None and is_cross_attention:
assert percent_through is not None
_, query_seq_len, _ = hidden_states.shape
if is_cross_attention:
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask

View File

@ -12,8 +12,11 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ExtraConditioningInfo,
IPAdapterConditioningInfo,
Range,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
from .cross_attention_control import (
CrossAttentionType,
@ -206,9 +209,9 @@ class InvokeAIDiffuserComponent:
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
):
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
)
@ -225,6 +228,7 @@ class InvokeAIDiffuserComponent:
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
percent_through=percent_through,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
@ -239,6 +243,7 @@ class InvokeAIDiffuserComponent:
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
percent_through=percent_through,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -301,6 +306,7 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -311,17 +317,13 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
cross_attention_kwargs = {}
if ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in ip_adapter_conditioning
]
}
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
for ipa_conditioning in ip_adapter_conditioning
]
added_cond_kwargs = None
if conditioning_data.is_sdxl():
@ -343,6 +345,31 @@ class InvokeAIDiffuserComponent:
),
}
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
regions = []
for c, r in [
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
(conditioning_data.cond_text, conditioning_data.cond_regions),
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
)
@ -366,6 +393,7 @@ class InvokeAIDiffuserComponent:
sigma,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
percent_through: float,
cross_attention_control_types_to_do: list[CrossAttentionType],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
@ -413,21 +441,19 @@ class InvokeAIDiffuserComponent:
# Unconditioned pass
#####################
cross_attention_kwargs = None
cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
}
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
# Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
@ -437,6 +463,13 @@ class InvokeAIDiffuserComponent:
"time_ids": conditioning_data.uncond_text.add_time_ids,
}
# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
@ -453,22 +486,20 @@ class InvokeAIDiffuserComponent:
# Conditioned pass
###################
cross_attention_kwargs = None
cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
}
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in ip_adapter_conditioning
]
# Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
@ -478,6 +509,13 @@ class InvokeAIDiffuserComponent:
"time_ids": conditioning_data.cond_text.add_time_ids,
}
# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,