Split ip_adapter_conditioning out from ConditioningData.

This commit is contained in:
Ryan Dick
2024-02-28 13:49:02 -05:00
parent e7ec13f209
commit ee1b3157ce
4 changed files with 23 additions and 19 deletions

View File

@ -71,5 +71,3 @@ class ConditioningData:
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None

View File

@ -14,6 +14,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningData,
ExtraConditioningInfo,
IPAdapterConditioningInfo,
SDXLConditioningInfo,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
@ -329,6 +330,7 @@ class InvokeAIDiffuserComponent:
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data: ConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
step_index: int,
total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
@ -353,6 +355,7 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
@ -366,6 +369,7 @@ class InvokeAIDiffuserComponent:
x=sample,
sigma=timestep,
conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -425,6 +429,7 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data: ConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -483,14 +488,14 @@ class InvokeAIDiffuserComponent:
}
# TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
for ipa_conditioning in ip_adapter_conditioning
]
}
@ -527,6 +532,7 @@ class InvokeAIDiffuserComponent:
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
cross_attention_control_types_to_do: list[CrossAttentionType],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
@ -581,12 +587,12 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
for ipa_conditioning in ip_adapter_conditioning
]
}
@ -622,12 +628,12 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None:
if ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
for ipa_conditioning in ip_adapter_conditioning
]
}