Add support for IP-Adapter masks.

This commit is contained in:
Ryan Dick
2024-03-14 16:58:11 -04:00
committed by Kent Keirsey
parent 2e27ed5f3d
commit 0bdbfd4d1d
7 changed files with 113 additions and 26 deletions

View File

@ -52,6 +52,7 @@ class IPAdapterConditioningInfo:
class IPAdapterData:
ip_adapter_model: IPAdapter
ip_adapter_conditioning: IPAdapterConditioningInfo
mask: torch.Tensor
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = 1.0

View File

@ -21,7 +21,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __init__(
self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
ip_adapter_scales: Optional[list[float]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@ -29,17 +28,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
Args:
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
for the i'th IP-Adapter.
ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_weights = ip_adapter_weights
self._ip_adapter_scales = ip_adapter_scales
assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
if self._ip_adapter_weights is not None:
assert len(ip_adapter_weights) == len(ip_adapter_scales)
def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None
@ -84,10 +75,10 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape
# Handle regional prompt attention masks.
if regional_prompt_data is not None and is_cross_attention:
assert percent_through is not None
_, query_seq_len, _ = hidden_states.shape
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
@ -141,9 +132,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
if is_cross_attention and self._is_ip_adapter_enabled():
if self._is_ip_adapter_enabled():
assert regional_ip_data is not None
for ipa_embed, ipa_weights, scale in zip(
regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True
):
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
assert (
len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_weights)
== len(regional_ip_data.scales)
== ip_masks.shape[1]
)
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_weights[ipa_index]
ipa_scale = regional_ip_data.scales[ipa_index]
ip_mask = ip_masks[0, ipa_index, ...]
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
@ -175,7 +175,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None

View File

@ -8,8 +8,14 @@ class RegionalIPData:
self,
image_prompt_embeds: list[torch.Tensor],
scales: list[float],
masks: list[torch.Tensor],
dtype: torch.dtype,
device: torch.device,
max_downscale_factor: int = 8,
):
"""Initialize a `IPAdapterConditioningData` object."""
assert len(image_prompt_embeds) == len(scales) == len(masks)
# The image prompt embeddings.
# regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
@ -18,3 +24,46 @@ class RegionalIPData:
# The scales for the IP-Adapter attention.
# scales[i] contains the attention scale for the i'th IP-Adapter.
self.scales = scales
# The IP-Adapter masks.
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
# s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
# regions and 0.0 for excluded regions.
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
def _prepare_masks(
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
) -> dict[int, torch.Tensor]:
"""Prepare the masks for the IP-Adapter attention."""
# Concatenate the masks so that they can be processed more efficiently.
mask_tensor = torch.cat(masks, dim=1)
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
masks_by_seq_len: dict[int, torch.Tensor] = {}
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, num_ip_adapters, h, w = mask_tensor.shape
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
assert b == 1
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
# the spatial features.
query_seq_len = h * w
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
# regions to be lost entirely.
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2)
return masks_by_seq_len
def get_masks(self, query_seq_len: int) -> torch.Tensor:
"""Get the mask for the given query sequence length."""
return self._masks_by_seq_len[query_seq_len]

View File

@ -286,7 +286,10 @@ class InvokeAIDiffuserComponent:
for ipa_conditioning in ip_adapter_conditioning
]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
@ -404,7 +407,10 @@ class InvokeAIDiffuserComponent:
]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the unconditioned pass.
@ -449,7 +455,10 @@ class InvokeAIDiffuserComponent:
]
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
ip_masks = [ipa.mask for ipa in ip_adapter_data]
regional_ip_data = RegionalIPData(
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
)
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
# Prepare SDXL conditioning kwargs for the conditioned pass.

View File

@ -12,10 +12,6 @@ class UNetAttentionPatcher:
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters
self._ip_adapter_scales = None
if self._ip_adapters is not None:
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@ -32,7 +28,6 @@ class UNetAttentionPatcher:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._ip_adapter_scales,
)
return attn_procs