Add support for multi-image IP-Adapter.

Ryan Dick 2023-10-13 14:44:42 -04:00
parent bf9f7271dd
commit 8464450a53
5 changed files with 37 additions and 12 deletions

@@ -32,7 +32,7 @@ class CLIPVisionModelField(BaseModel):
class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the IP-Adapter")
@@ -56,7 +56,7 @@ class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
)
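
For reference, a minimal standalone sketch (not part of this commit) of how the Union[ImageField, List[ImageField]] fields above behave, assuming a simplified ImageField that only carries an image_name:

from typing import List, Union
from pydantic import BaseModel, Field

class ImageField(BaseModel):
    image_name: str

class IPAdapterFieldSketch(BaseModel):
    # Accept either a single image prompt or a list of them, as in IPAdapterField above.
    image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")

def normalize_images(field: IPAdapterFieldSketch) -> List[ImageField]:
    # Mirror the downstream normalization: wrap a single ImageField in a one-element list.
    images = field.image
    if not isinstance(images, list):
        images = [images]
    return images

# Both forms validate, and both normalize to a list.
one = IPAdapterFieldSketch(image=ImageField(image_name="a.png"))
many = IPAdapterFieldSketch(image=[ImageField(image_name="a.png"), ImageField(image_name="b.png")])
assert len(normalize_images(one)) == 1
assert len(normalize_images(many)) == 2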

@@ -445,14 +445,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
)
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_images = single_ip_adapter.image
if not isinstance(single_ipa_images, list):
single_ipa_images = [single_ipa_images]
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
single_ipa_images, image_encoder_model
)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
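
The per-adapter embeddings produced here are image-stacked, i.e. shaped (num_images, num_tokens, encoding_dim), matching the IPAdapterConditioningInfo docstrings below. A shape-only sketch with hypothetical sizes (4 tokens per image, 768-dim encodings):

import torch

num_images, num_tokens, encoding_dim = 2, 4, 768

# One (num_tokens, encoding_dim) embedding per input image; random stand-ins here in place
# of the CLIP Vision encoder + image projection outputs.
per_image = [torch.randn(num_tokens, encoding_dim) for _ in range(num_images)]

# Stacking yields the (num_images, num_tokens, encoding_dim) layout used downstream.
image_prompt_embeds = torch.stack(per_image)
assert image_prompt_embeds.shape == (num_images, num_tokens, encoding_dim)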

@@ -67,6 +67,12 @@ class IPAttnProcessor2_0(torch.nn.Module):
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.
Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states
if attn.spatial_norm is not None:
@@ -127,26 +133,35 @@ class IPAttnProcessor2_0(torch.nn.Module):
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# ip_hidden_state.shape = (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# ip_key.shape, ip_value.shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
# ip_key.shape, ip_value.shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# ip_hidden_states.shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# ip_hidden_states.shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj
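
A shape-only walk-through of the multi-image attention path above, using random tensors and hypothetical sizes; it only verifies that the view/transpose steps and scaled_dot_product_attention line up:

import torch
import torch.nn.functional as F

batch_size, num_ip_images, ip_seq_len = 2, 3, 4
num_heads, head_dim = 8, 64
inner_dim = num_heads * head_dim
query_seq_len = 77  # hypothetical latent token count

# Query from the UNet attention layer: (batch_size, num_heads, query_seq_len, head_dim).
query = torch.randn(batch_size, num_heads, query_seq_len, head_dim)

# Image prompt embeds after the to_k_ip / to_v_ip projections:
# (batch_size, num_ip_images, ip_seq_len, num_heads * head_dim).
ip_key = torch.randn(batch_size, num_ip_images, ip_seq_len, inner_dim)
ip_value = torch.randn(batch_size, num_ip_images, ip_seq_len, inner_dim)

# view + transpose folds the image and sequence dims together per head:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim).
ip_key = ip_key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
assert ip_key.shape == (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)

# Attention attends over all tokens of all IP images at once.
ip_hidden_states = F.scaled_dot_product_attention(
    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
assert ip_hidden_states.shape == (batch_size, num_heads, query_seq_len, head_dim)

# Back to (batch_size, query_seq_len, num_heads * head_dim) before the scaled residual addition.
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
assert ip_hidden_states.shape == (batch_size, query_seq_len, inner_dim)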

@@ -55,11 +55,11 @@ class PostprocessingSettings:
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
Shape: (batch_size, num_tokens, encoding_dim).
Shape: (num_images, num_tokens, encoding_dim).
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
Shape: (batch_size, num_tokens, encoding_dim).
Shape: (num_images, num_tokens, encoding_dim).
"""

@@ -345,9 +345,12 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
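
A quick check of why torch.stack (rather than torch.cat) is used here: each per-adapter tensor is now (num_images, seq_len, token_len), and stacking the unconditional and conditional embeddings adds the leading batch dimension expected by the attention processor. Hypothetical sizes:

import torch

num_images, seq_len, token_len = 2, 4, 768
uncond = torch.randn(num_images, seq_len, token_len)
cond = torch.randn(num_images, seq_len, token_len)

stacked = torch.stack([uncond, cond])
assert stacked.shape == (2, num_images, seq_len, token_len)  # (batch_size, num_ip_images, seq_len, token_len)

# torch.cat would instead merge along the image dimension and drop the batch axis:
assert torch.cat([uncond, cond]).shape == (2 * num_images, seq_len, token_len)
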
@@ -415,9 +418,10 @@ class InvokeAIDiffuserComponent:
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.uncond_image_prompt_embeds
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
@@ -444,9 +448,10 @@ class InvokeAIDiffuserComponent:
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.cond_image_prompt_embeds
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
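
In the sequential guidance paths above, each embedding tensor is (num_images, seq_len, token_len) on its own, so torch.unsqueeze(..., dim=0) supplies the batch_size=1 axis the attention processor expects. With the same hypothetical sizes:

import torch

num_images, seq_len, token_len = 2, 4, 768
cond = torch.randn(num_images, seq_len, token_len)

batched = torch.unsqueeze(cond, dim=0)
assert batched.shape == (1, num_images, seq_len, token_len)  # (batch_size=1, num_ip_images, seq_len, token_len)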