mirror of https://github.com/invoke-ai/InvokeAI

commit ee34091bdb (parent 787a085efc)

Update the diffusion logic to use the new regional prompting feature.

The diff spans three classes: StableDiffusionGeneratorPipeline, CustomAttnProcessor2_0, and InvokeAIDiffuserComponent.
@@ -404,22 +404,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents

-        ip_adapter_unet_patcher = None
         extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = (
+            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+        )
+        if use_cross_attention_control and use_ip_adapter:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
+            )
+        if use_cross_attention_control and use_regional_prompting:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
+            )
+
+        unet_attention_patcher = None
+        self.use_ip_adapter = use_ip_adapter
+        attn_ctx = nullcontext()
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
             )
-            self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
-            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
-            # As it is now, the IP-Adapter will silently be skipped.
-            ip_adapter_unet_patcher = UNetAttentionPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
-            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-            self.use_ip_adapter = True
-        else:
-            attn_ctx = nullcontext()
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
             if callback is not None:
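The hunk above replaces the old if/elif/else chain with explicit feature flags and makes the unsupported combinations fail loudly: previously, when `.swap()` cross-attention control was active, an IP-Adapter was silently skipped (per the removed TODO). Below is a minimal, self-contained sketch of the same gating logic, using strings as stand-ins for the real context managers; only the flag names come from the diff.

```python
from itertools import product

def select_attention_setup(use_cross_attention_control: bool, use_ip_adapter: bool,
                           use_regional_prompting: bool) -> str:
    # Mirrors the gating above: `.swap()` control is mutually exclusive with
    # both IP-Adapter and regional prompting.
    if use_cross_attention_control and use_ip_adapter:
        raise ValueError("cross-attention control + IP-Adapter is unsupported")
    if use_cross_attention_control and use_regional_prompting:
        raise ValueError("cross-attention control + regional prompting is unsupported")
    if use_ip_adapter or use_regional_prompting:
        return "UNetAttentionPatcher"  # stand-in for apply_ip_adapter_attention(...)
    if use_cross_attention_control:
        return "custom_attention_context"
    return "nullcontext"  # stock attention, no patching

# Print the full truth table of the gating logic:
for flags in product([False, True], repeat=3):
    try:
        print(flags, "->", select_attention_setup(*flags))
    except ValueError as err:
        print(flags, "-> ValueError:", err)
```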
@@ -447,7 +460,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
+                    unet_attention_patcher=unet_attention_patcher,
                 )
                 latents = step_output.prev_sample
                 predicted_original = getattr(step_output, "pred_original_sample", None)

@@ -479,7 +492,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetAttentionPatcher] = None,
+        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]

@@ -506,10 +519,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    ip_adapter_unet_patcher.set_scale(i, weight)
+                    unet_attention_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    ip_adapter_unet_patcher.set_scale(i, 0.0)
+                    unet_attention_patcher.set_scale(i, 0.0)

         # Handle ControlNet(s)
         down_block_additional_residuals = None
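The renamed `unet_attention_patcher` keeps the per-step IP-Adapter weighting unchanged: inside the adapter's begin/end step window, `set_scale` applies the configured weight; outside it, the scale is zeroed. A small sketch of that schedule follows; the percentage-based parameters are hypothetical stand-ins for however `first_adapter_step` and `last_adapter_step` are derived upstream.

```python
def ip_adapter_scale_for_step(step_index: int, total_steps: int,
                              begin_percent: float, end_percent: float,
                              weight: float) -> float:
    # Hypothetical derivation of the step window from begin/end percentages.
    first_adapter_step = int(round(begin_percent * total_steps))
    last_adapter_step = int(round(end_percent * total_steps))
    if first_adapter_step <= step_index <= last_adapter_step:
        return weight  # adapter active: use its configured weight
    return 0.0  # outside the window: scale 0.0 disables the adapter

# An adapter restricted to the first half of a 30-step run:
print([ip_adapter_scale_for_step(s, 30, 0.0, 0.5, 0.8) for s in (0, 10, 20)])
# [0.8, 0.8, 0.0]
```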
@@ -88,10 +88,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # End unmodified block from AttnProcessor2_0.

         # Handle regional prompt attention masks.
-        if regional_prompt_data is not None:
+        if regional_prompt_data is not None and is_cross_attention:
             assert percent_through is not None
             _, query_seq_len, _ = hidden_states.shape
-            if is_cross_attention:
-                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-                    query_seq_len=query_seq_len, key_seq_len=sequence_length
-                )
+            prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                query_seq_len=query_seq_len, key_seq_len=sequence_length
+            )
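In the attention processor, the `is_cross_attention` check moves into the outer condition, so regional masks are now built only for cross-attention (image tokens attending to text tokens), and self-attention falls through untouched. The sketch below illustrates, under stated assumptions, what a `(query_seq_len, key_seq_len)` mask can look like when assembled from per-region spatial masks and prompt-token ranges; it is not InvokeAI's `get_cross_attn_mask`, just a plausible reconstruction of the idea.

```python
import torch

def build_cross_attn_mask(spatial_masks: torch.Tensor,
                          token_ranges: list[tuple[int, int]],
                          query_seq_len: int, key_seq_len: int) -> torch.Tensor:
    """Illustrative regional cross-attention mask (not InvokeAI's implementation).

    spatial_masks: (num_regions, H, W) boolean masks over the image.
    token_ranges:  per-region (start, end) slices into the prompt tokens.
    Returns a (query_seq_len, key_seq_len) boolean mask where True means
    'this image token may attend to this text token'.
    """
    num_regions = spatial_masks.shape[0]
    side = int(query_seq_len ** 0.5)  # assume a square latent at this attention layer
    masks = torch.nn.functional.interpolate(
        spatial_masks[None].float(), size=(side, side), mode="nearest"
    )[0].bool().reshape(num_regions, -1)  # (num_regions, query_seq_len)

    attn_mask = torch.zeros(query_seq_len, key_seq_len, dtype=torch.bool)
    for region_idx, (start, end) in enumerate(token_ranges):
        # Image tokens inside this region may attend to this region's tokens.
        attn_mask[masks[region_idx], start:end] = True
    return attn_mask

mask = build_cross_attn_mask(torch.ones(1, 64, 64, dtype=torch.bool), [(0, 77)], 4096, 77)
print(mask.shape)  # torch.Size([4096, 77])
```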
@@ -12,8 +12,11 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ExtraConditioningInfo,
     IPAdapterConditioningInfo,
+    Range,
     TextConditioningData,
+    TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

 from .cross_attention_control import (
     CrossAttentionType,
@@ -206,9 +209,9 @@ class InvokeAIDiffuserComponent:
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
     ):
+        percent_through = step_index / total_step_count
         cross_attention_control_types_to_do = []
         if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
             cross_attention_control_types_to_do = (
                 self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
             )
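`percent_through` (the fraction of the denoising schedule completed) was previously computed only when cross-attention control was active; it is hoisted out of the conditional because regional prompting now needs it on every step. The computation itself is just:

```python
# e.g. step 12 of a 30-step schedule is 40% of the way through denoising:
step_index, total_step_count = 12, 30
percent_through = step_index / total_step_count
print(percent_through)  # 0.4
```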
@@ -225,6 +228,7 @@ class InvokeAIDiffuserComponent:
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,

@@ -239,6 +243,7 @@ class InvokeAIDiffuserComponent:
                 sigma=timestep,
                 conditioning_data=conditioning_data,
                 ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                 down_block_additional_residuals=down_block_additional_residuals,
                 mid_block_additional_residual=mid_block_additional_residual,
                 down_intrablock_additional_residuals=down_intrablock_additional_residuals,

@@ -301,6 +306,7 @@ class InvokeAIDiffuserComponent:
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
@@ -311,17 +317,13 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}
         if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         added_cond_kwargs = None
         if conditioning_data.is_sdxl():
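`cross_attention_kwargs` switches from "None, then wholesale replacement" to an empty dict that accumulates entries, so the IP-Adapter embeddings and the regional-prompt entries added later in this method can coexist. The shape noted in the diff's comment can be sanity-checked with toy tensors (the sizes here are arbitrary placeholders):

```python
import torch

# Arbitrary placeholder sizes for one IP-Adapter's image-prompt embeddings.
num_ip_images, seq_len, token_len = 2, 4, 768
uncond = torch.zeros(num_ip_images, seq_len, token_len)
cond = torch.ones(num_ip_images, seq_len, token_len)

cross_attention_kwargs = {}
# 'stack' prepends the uncond/cond batch axis, as the comment in the diff says:
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [torch.stack([uncond, cond])]
print(cross_attention_kwargs["ip_adapter_image_prompt_embeds"][0].shape)
# torch.Size([2, 2, 4, 768])  i.e. (batch_size, num_ip_images, seq_len, token_len)
```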
@@ -343,6 +345,31 @@ class InvokeAIDiffuserComponent:
                 ),
             }

+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
             conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
         )
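When only one half of the classifier-free-guidance pair carries region masks, the other half gets a dummy `TextConditioningRegions`: a single all-ones mask covering the whole latent, and a single `Range` spanning every prompt token, so both batch entries look uniform to `RegionalPromptData`. A standalone sketch of that fallback, with a local stand-in for the imported `Range` (field names taken from the diff):

```python
import torch
from dataclasses import dataclass

@dataclass
class Range:  # local stand-in mirroring the imported Range(start=..., end=...)
    start: int
    end: int

def dummy_region(latent_h: int, latent_w: int, num_tokens: int):
    # One region that covers everything, one range that spans every token.
    masks = torch.ones((1, 1, latent_h, latent_w), dtype=torch.bool)
    ranges = [Range(start=0, end=num_tokens)]
    return masks, ranges

masks, ranges = dummy_region(64, 64, 77)
print(masks.shape, ranges)  # torch.Size([1, 1, 64, 64]) [Range(start=0, end=77)]
```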
@@ -366,6 +393,7 @@ class InvokeAIDiffuserComponent:
         sigma,
         conditioning_data: TextConditioningData,
         ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
         cross_attention_control_types_to_do: list[CrossAttentionType],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet

@@ -413,21 +441,19 @@ class InvokeAIDiffuserComponent:
         # Unconditioned pass
         #####################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         # Prepare cross-attention control kwargs for the unconditioned pass.
         if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
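One consequence of the dict-accumulation style is worth noting: in the old code, the `swap_cross_attn_context` assignment replaced the whole dict, silently dropping any IP-Adapter entry set just above (the removed TODO in the first hunk called this out). With per-key assignment the entries merge, and the new ValueError guards make the conflicting combination unreachable anyway:

```python
# Old style: the second assignment clobbers the IP-Adapter entry.
kwargs = {"ip_adapter_image_prompt_embeds": ["embeds..."]}
kwargs = {"swap_cross_attn_context": "ctx"}
print(sorted(kwargs))  # ['swap_cross_attn_context']

# New style: entries accumulate.
kwargs = {}
kwargs["ip_adapter_image_prompt_embeds"] = ["embeds..."]
kwargs["swap_cross_attn_context"] = "ctx"
print(sorted(kwargs))  # ['ip_adapter_image_prompt_embeds', 'swap_cross_attn_context']
```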
@@ -437,6 +463,13 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.uncond_text.add_time_ids,
             }

+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
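Note the asymmetry with the batched path: there, `RegionalPromptData` always receives a two-element `regions` list (dummies filled in as needed), while each sequential pass wraps just its own side in a one-element list and attaches the kwargs only when that side actually has regions. A tiny sketch of that per-pass guard, with `list` standing in for the `RegionalPromptData` constructor:

```python
def regional_kwargs_for_pass(regions, percent_through, make_regional_prompt_data=list):
    # Sequential-pass behavior: attach regional kwargs only if this side has regions.
    kwargs = {}
    if regions is not None:
        kwargs["regional_prompt_data"] = make_regional_prompt_data([regions])  # batch of one
        kwargs["percent_through"] = percent_through
    return kwargs

print(regional_kwargs_for_pass(None, 0.4))       # {} - no regions on this side
print(regional_kwargs_for_pass("regions", 0.4))  # both keys attached
```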
@@ -453,22 +486,20 @@ class InvokeAIDiffuserComponent:
         # Conditioned pass
         ###################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

         # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
         if ip_adapter_conditioning is not None:
             # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

         # Prepare cross-attention control kwargs for the conditioned pass.
         if cross_attn_processor_context is not None:
             cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None

@@ -478,6 +509,13 @@ class InvokeAIDiffuserComponent:
                 "time_ids": conditioning_data.cond_text.add_time_ids,
             }

+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
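Taken together: a caller opting into regional prompting supplies per-region spatial masks plus per-region token ranges, and everything in this commit is plumbing to carry that data into the attention processors. A schematic two-region setup with stand-in data (shapes follow the dummy-mask branch above; the real constructor signatures may differ):

```python
import torch
from dataclasses import dataclass

@dataclass
class Range:  # stand-in, as before
    start: int
    end: int

# Two regions over a 64x64 latent: left half and right half.
h, w = 64, 64
left = torch.zeros((h, w), dtype=torch.bool)
left[:, : w // 2] = True
right = ~left
masks = torch.stack([left, right])[None]  # (1, num_regions=2, h, w)

# Each region attends only to its own slice of the 77 prompt tokens:
ranges = [Range(start=0, end=38), Range(start=38, end=77)]
print(masks.shape, [(r.start, r.end) for r in ranges])
# torch.Size([1, 2, 64, 64]) [(0, 38), (38, 77)]
```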