mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for multi-image IP-Adapter.
This commit is contained in:
parent
bf9f7271dd
commit
8464450a53
@ -32,7 +32,7 @@ class CLIPVisionModelField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class IPAdapterField(BaseModel):
|
class IPAdapterField(BaseModel):
|
||||||
image: ImageField = Field(description="The IP-Adapter image prompt.")
|
image: Union[ImageField, List[ImageField]] = Field(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
|
||||||
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
|
||||||
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||||
@ -56,7 +56,7 @@ class IPAdapterInvocation(BaseInvocation):
|
|||||||
"""Collects IP-Adapter info to pass to other nodes."""
|
"""Collects IP-Adapter info to pass to other nodes."""
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: ImageField = InputField(description="The IP-Adapter image prompt.")
|
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).")
|
||||||
ip_adapter_model: IPAdapterModelField = InputField(
|
ip_adapter_model: IPAdapterModelField = InputField(
|
||||||
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
description="The IP-Adapter model.", title="IP-Adapter Model", input=Input.Direct, ui_order=-1
|
||||||
)
|
)
|
||||||
|
@ -445,14 +445,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
context=context,
|
context=context,
|
||||||
)
|
)
|
||||||
|
|
||||||
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
|
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||||
|
single_ipa_images = single_ip_adapter.image
|
||||||
|
if not isinstance(single_ipa_images, list):
|
||||||
|
single_ipa_images = [single_ipa_images]
|
||||||
|
|
||||||
|
single_ipa_images = [context.services.images.get_pil_image(image.image_name) for image in single_ipa_images]
|
||||||
|
|
||||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||||
with image_encoder_model_info as image_encoder_model:
|
with image_encoder_model_info as image_encoder_model:
|
||||||
# Get image embeddings from CLIP and ImageProjModel.
|
# Get image embeddings from CLIP and ImageProjModel.
|
||||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||||
input_image, image_encoder_model
|
single_ipa_images, image_encoder_model
|
||||||
)
|
)
|
||||||
conditioning_data.ip_adapter_conditioning.append(
|
conditioning_data.ip_adapter_conditioning.append(
|
||||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||||
|
@ -67,6 +67,12 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||||||
temb=None,
|
temb=None,
|
||||||
ip_adapter_image_prompt_embeds=None,
|
ip_adapter_image_prompt_embeds=None,
|
||||||
):
|
):
|
||||||
|
"""Apply IP-Adapter attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
||||||
|
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||||
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
if attn.spatial_norm is not None:
|
||||||
@ -127,26 +133,35 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
|||||||
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
|
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
|
||||||
# The batch dimensions should match.
|
# The batch dimensions should match.
|
||||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||||
# The channel dimensions should match.
|
# The token_len dimensions should match.
|
||||||
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
|
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||||
|
|
||||||
ip_hidden_states = ipa_embed
|
ip_hidden_states = ipa_embed
|
||||||
|
|
||||||
|
# ip_hidden_state.shape = (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||||
|
|
||||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||||
|
|
||||||
|
# ip_key.shape, ip_value.shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||||
|
|
||||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
|
# ip_key.shape, ip_value.shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||||
|
|
||||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||||
ip_hidden_states = F.scaled_dot_product_attention(
|
ip_hidden_states = F.scaled_dot_product_attention(
|
||||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ip_hidden_states.shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||||
|
|
||||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
# ip_hidden_states.shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||||
|
|
||||||
hidden_states = hidden_states + scale * ip_hidden_states
|
hidden_states = hidden_states + scale * ip_hidden_states
|
||||||
|
|
||||||
# linear proj
|
# linear proj
|
||||||
|
@ -55,11 +55,11 @@ class PostprocessingSettings:
|
|||||||
class IPAdapterConditioningInfo:
|
class IPAdapterConditioningInfo:
|
||||||
cond_image_prompt_embeds: torch.Tensor
|
cond_image_prompt_embeds: torch.Tensor
|
||||||
"""IP-Adapter image encoder conditioning embeddings.
|
"""IP-Adapter image encoder conditioning embeddings.
|
||||||
Shape: (batch_size, num_tokens, encoding_dim).
|
Shape: (num_images, num_tokens, encoding_dim).
|
||||||
"""
|
"""
|
||||||
uncond_image_prompt_embeds: torch.Tensor
|
uncond_image_prompt_embeds: torch.Tensor
|
||||||
"""IP-Adapter image encoding embeddings to use for unconditional generation.
|
"""IP-Adapter image encoding embeddings to use for unconditional generation.
|
||||||
Shape: (batch_size, num_tokens, encoding_dim).
|
Shape: (num_images, num_tokens, encoding_dim).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@ -345,9 +345,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
|
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
cross_attention_kwargs = {
|
||||||
"ip_adapter_image_prompt_embeds": [
|
"ip_adapter_image_prompt_embeds": [
|
||||||
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
torch.stack(
|
||||||
|
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||||
|
)
|
||||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -415,9 +418,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
# Run unconditional UNet denoising.
|
# Run unconditional UNet denoising.
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
|
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
cross_attention_kwargs = {
|
||||||
"ip_adapter_image_prompt_embeds": [
|
"ip_adapter_image_prompt_embeds": [
|
||||||
ipa_conditioning.uncond_image_prompt_embeds
|
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -444,9 +448,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
# Run conditional UNet denoising.
|
# Run conditional UNet denoising.
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
|
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||||
cross_attention_kwargs = {
|
cross_attention_kwargs = {
|
||||||
"ip_adapter_image_prompt_embeds": [
|
"ip_adapter_image_prompt_embeds": [
|
||||||
ipa_conditioning.cond_image_prompt_embeds
|
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user