Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts).

This commit is contained in:
Ryan Dick 2023-09-08 11:47:36 -04:00
parent ddc148b70b
commit b2d5b53b5f
5 changed files with 135 additions and 68 deletions

View File

@ -19,12 +19,42 @@ class AttnProcessor(DiffusersAttnProcessor, nn.Module):
DiffusersAttnProcessor.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor(nn.Module):
r"""
@ -32,21 +62,17 @@ class IPAttnProcessor(nn.Module):
Args:
hidden_size (`int`):
The hidden size of the attention layer.
image_embedding_len (`int`):
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
the `encoder_hidden_states` are the IP-Adapter image embeddings.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.image_embedding_len = image_embedding_len
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@ -59,7 +85,18 @@ class IPAttnProcessor(nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@ -86,12 +123,6 @@ class IPAttnProcessor(nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# Split text encoder hidden states and image encoder hidden state.
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, : -self.image_embedding_len, :],
encoder_hidden_states[:, -self.image_embedding_len :, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -103,18 +134,18 @@ class IPAttnProcessor(nn.Module):
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@ -138,16 +169,13 @@ class IPAttnProcessor2_0(torch.nn.Module):
Args:
hidden_size (`int`):
The hidden size of the attention layer.
image_embedding_len (`int`):
The length of the IP-Adapter image embedding. It is assumed that the last `image_embedding_len` 'tokens' of
the `encoder_hidden_states` are the IP-Adapter image embeddings.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, image_embedding_len, cross_attention_dim=None, scale=1.0):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
@ -155,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.text_context_len = text_context_len
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@ -168,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@ -200,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# Split text encoder hidden states and image encoder hidden state.
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, : -self.image_embedding_len, :],
encoder_hidden_states[:, -self.image_embedding_len :, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -226,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)

View File

@ -92,7 +92,6 @@ class IPAdapter:
print("swapping in IPAttnProcessor for", name)
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size,
image_embedding_len=self.num_tokens,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=torch.float16)

View File

@ -30,6 +30,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
IPAdapterConditioningInfo,
)
from ..util import auto_detect_slice_size, normalize_device
@ -449,27 +450,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
# The following commented block is kept for reference on how to repeat/reshape the image embeddings to
# generate a batch of multiple images:
# bs_embed, seq_len, _ = image_prompt_embeds.shape
# image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
# image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
# uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
text_prompt_embeds = conditioning_data.text_embeddings.embeds
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
print("text embeds shape:", text_prompt_embeds.shape)
concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
print("concat embeds shape:", concat_prompt_embeds.shape)
conditioning_data.text_embeddings.embeds = concat_prompt_embeds
conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
else:
image_prompt_embeds = None
uncond_image_prompt_embeds = None
# TODO(ryand): Apply IP-Adapter or custom attention control
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,

View File

@ -51,6 +51,18 @@ class PostprocessingSettings:
v_symmetry_time_pct: Optional[float]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
Shape: (batch_size, num_tokens, encoding_dim). Typically: (1, 4, 1024) TODO(ryand): confirm
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
Shape: (batch_size, num_tokens, encoding_dim). Typically: (1, 4, 1024) TODO(ryand): confirm
"""
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
@ -69,6 +81,8 @@ class ConditioningData:
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
@property
def dtype(self):
return self.text_embeddings.dtype

View File

@ -10,6 +10,7 @@ from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
@ -232,6 +233,8 @@ class InvokeAIDiffuserComponent:
total_step_count: int,
**kwargs,
):
# TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
@ -339,11 +342,24 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
# fast batched path
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
}
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
@ -371,6 +387,7 @@ class InvokeAIDiffuserComponent:
x_twice,
sigma_twice,
both_conditionings,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
@ -382,9 +399,12 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data,
conditioning_data: ConditioningData,
**kwargs,
):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# low-memory sequential path
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@ -400,6 +420,13 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
}
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
@ -412,12 +439,21 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
}
added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@ -428,6 +464,7 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,