Add support for IP-Adapter masks.

This commit is contained in:
Ryan Dick 2024-03-14 16:58:11 -04:00 committed by Kent Keirsey
parent 2e27ed5f3d
commit 0bdbfd4d1d
7 changed files with 113 additions and 26 deletions

View File

@ -1,11 +1,23 @@
from builtins import float from builtins import float
from typing import List, Literal, Union from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
OutputField,
TensorField,
UIType,
)
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@ -30,6 +42,11 @@ class IPAdapterField(BaseModel):
end_step_percent: float = Field( end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
) )
mask: Optional[TensorField] = Field(
default=None,
description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included "
"regions should be set to True.",
)
@field_validator("weight") @field_validator("weight")
@classmethod @classmethod
@ -52,7 +69,7 @@ class IPAdapterOutput(BaseInvocationOutput):
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
@ -79,6 +96,9 @@ class IPAdapterInvocation(BaseInvocation):
end_step_percent: float = InputField( end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
) )
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this IP-Adapter applies to."
)
@field_validator("weight") @field_validator("weight")
@classmethod @classmethod
@ -112,6 +132,7 @@ class IPAdapterInvocation(BaseInvocation):
weight=self.weight, weight=self.weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
mask=self.mask,
), ),
) )

View File

@ -633,6 +633,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext, context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
exit_stack: ExitStack, exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[list[IPAdapterData]]: ) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place). to the `conditioning_data` (in-place).
@ -670,6 +673,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images, image_encoder_model single_ipa_images, image_encoder_model
) )
mask = single_ip_adapter.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
ip_adapter_data_list.append( ip_adapter_data_list.append(
IPAdapterData( IPAdapterData(
ip_adapter_model=ip_adapter_model, ip_adapter_model=ip_adapter_model,
@ -677,6 +685,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
begin_step_percent=single_ip_adapter.begin_step_percent, begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent, end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
mask=mask,
) )
) )
@ -916,6 +925,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context, context=context,
ip_adapter=self.ip_adapter, ip_adapter=self.ip_adapter,
exit_stack=exit_stack, exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
) )
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(

View File

@ -52,6 +52,7 @@ class IPAdapterConditioningInfo:
class IPAdapterData: class IPAdapterData:
ip_adapter_model: IPAdapter ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
# Either a single weight applied to all steps, or a list of weights for each step. # Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0 weight: Union[float, List[float]] = 1.0

View File

@ -21,7 +21,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __init__( def __init__(
self, self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None, ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
ip_adapter_scales: Optional[list[float]] = None,
): ):
"""Initialize a CustomAttnProcessor2_0. """Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@ -29,17 +28,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
Args: Args:
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
for the i'th IP-Adapter. for the i'th IP-Adapter.
ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
the i'th IP-Adapter.
""" """
super().__init__() super().__init__()
self._ip_adapter_weights = ip_adapter_weights self._ip_adapter_weights = ip_adapter_weights
self._ip_adapter_scales = ip_adapter_scales
assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
if self._ip_adapter_weights is not None:
assert len(ip_adapter_weights) == len(ip_adapter_scales)
def _is_ip_adapter_enabled(self) -> bool: def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None return self._ip_adapter_weights is not None
@ -84,10 +75,10 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0. # End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape
# Handle regional prompt attention masks. # Handle regional prompt attention masks.
if regional_prompt_data is not None and is_cross_attention: if regional_prompt_data is not None and is_cross_attention:
assert percent_through is not None assert percent_through is not None
_, query_seq_len, _ = hidden_states.shape
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length query_seq_len=query_seq_len, key_seq_len=sequence_length
) )
@ -141,9 +132,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
if is_cross_attention and self._is_ip_adapter_enabled(): if is_cross_attention and self._is_ip_adapter_enabled():
if self._is_ip_adapter_enabled(): if self._is_ip_adapter_enabled():
assert regional_ip_data is not None assert regional_ip_data is not None
for ipa_embed, ipa_weights, scale in zip( ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True assert (
): len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_weights[ipa_index]
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
# The batch dimensions should match. # The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match. # The token_len dimensions should match.
@ -175,7 +175,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else: else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in. # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None assert regional_ip_data is None

View File

@ -8,8 +8,14 @@ class RegionalIPData:
self, self,
image_prompt_embeds: list[torch.Tensor], image_prompt_embeds: list[torch.Tensor],
scales: list[float], scales: list[float],
masks: list[torch.Tensor],
dtype: torch.dtype,
device: torch.device,
max_downscale_factor: int = 8,
): ):
"""Initialize a `IPAdapterConditioningData` object.""" """Initialize a `IPAdapterConditioningData` object."""
assert len(image_prompt_embeds) == len(scales) == len(masks)
# The image prompt embeddings. # The image prompt embeddings.
# regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor # regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len). # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
@ -18,3 +24,46 @@ class RegionalIPData:
# The scales for the IP-Adapter attention. # The scales for the IP-Adapter attention.
# scales[i] contains the attention scale for the i'th IP-Adapter. # scales[i] contains the attention scale for the i'th IP-Adapter.
self.scales = scales self.scales = scales
# The IP-Adapter masks.
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
# s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
# regions and 0.0 for excluded regions.
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
def _prepare_masks(
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
) -> dict[int, torch.Tensor]:
"""Prepare the masks for the IP-Adapter attention."""
# Concatenate the masks so that they can be processed more efficiently.
mask_tensor = torch.cat(masks, dim=1)
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
masks_by_seq_len: dict[int, torch.Tensor] = {}
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, num_ip_adapters, h, w = mask_tensor.shape
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
assert b == 1
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
# the spatial features.
query_seq_len = h * w
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
# regions to be lost entirely.
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2)
return masks_by_seq_len
def get_masks(self, query_seq_len: int) -> torch.Tensor:
"""Get the mask for the given query sequence length."""
return self._masks_by_seq_len[query_seq_len]

View File

@ -286,7 +286,10 @@ class InvokeAIDiffuserComponent:
for ipa_conditioning in ip_adapter_conditioning for ipa_conditioning in ip_adapter_conditioning
] ]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales) ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None added_cond_kwargs = None
@ -404,7 +407,10 @@ class InvokeAIDiffuserComponent:
] ]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales) ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the unconditioned pass. # Prepare SDXL conditioning kwargs for the unconditioned pass.
@ -449,7 +455,10 @@ class InvokeAIDiffuserComponent:
] ]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales) ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the conditioned pass. # Prepare SDXL conditioning kwargs for the conditioned pass.

View File

@ -12,10 +12,6 @@ class UNetAttentionPatcher:
def __init__(self, ip_adapters: Optional[list[IPAdapter]]): def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters self._ip_adapters = ip_adapters
self._ip_adapter_scales = None
if self._ip_adapters is not None:
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
def _prepare_attention_processors(self, unet: UNet2DConditionModel): def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@ -32,7 +28,6 @@ class UNetAttentionPatcher:
# Collect the weights from each IP Adapter for the idx'th attention processor. # Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0( attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._ip_adapter_scales,
) )
return attn_procs return attn_procs