mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for IP-Adapter masks.
This commit is contained in:
parent
2e27ed5f3d
commit
0bdbfd4d1d
@ -1,11 +1,23 @@
|
||||
from builtins import float
|
||||
from typing import List, Literal, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
@ -30,6 +42,11 @@ class IPAdapterField(BaseModel):
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included "
|
||||
"regions should be set to True.",
|
||||
)
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
@ -52,7 +69,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@ -79,6 +96,9 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
|
||||
)
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this IP-Adapter applies to."
|
||||
)
|
||||
|
||||
@field_validator("weight")
|
||||
@classmethod
|
||||
@ -112,6 +132,7 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
weight=self.weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
mask=self.mask,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -633,6 +633,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
exit_stack: ExitStack,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
to the `conditioning_data` (in-place).
|
||||
@ -670,6 +673,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
|
||||
mask = single_ip_adapter.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
|
||||
ip_adapter_data_list.append(
|
||||
IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
@ -677,6 +685,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
mask=mask,
|
||||
)
|
||||
)
|
||||
|
||||
@ -916,6 +925,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
exit_stack=exit_stack,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
dtype=unet.dtype,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
|
@ -52,6 +52,7 @@ class IPAdapterConditioningInfo:
|
||||
class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter
|
||||
ip_adapter_conditioning: IPAdapterConditioningInfo
|
||||
mask: torch.Tensor
|
||||
|
||||
# Either a single weight applied to all steps, or a list of weights for each step.
|
||||
weight: Union[float, List[float]] = 1.0
|
||||
|
@ -21,7 +21,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
|
||||
ip_adapter_scales: Optional[list[float]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
@ -29,17 +28,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
Args:
|
||||
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
|
||||
for the i'th IP-Adapter.
|
||||
ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
|
||||
the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._ip_adapter_weights = ip_adapter_weights
|
||||
self._ip_adapter_scales = ip_adapter_scales
|
||||
|
||||
assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
|
||||
if self._ip_adapter_weights is not None:
|
||||
assert len(ip_adapter_weights) == len(ip_adapter_scales)
|
||||
|
||||
def _is_ip_adapter_enabled(self) -> bool:
|
||||
return self._ip_adapter_weights is not None
|
||||
@ -84,10 +75,10 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
# Handle regional prompt attention masks.
|
||||
if regional_prompt_data is not None and is_cross_attention:
|
||||
assert percent_through is not None
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
||||
query_seq_len=query_seq_len, key_seq_len=sequence_length
|
||||
)
|
||||
@ -141,9 +132,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
if is_cross_attention and self._is_ip_adapter_enabled():
|
||||
if self._is_ip_adapter_enabled():
|
||||
assert regional_ip_data is not None
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True
|
||||
):
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
assert (
|
||||
len(regional_ip_data.image_prompt_embeds)
|
||||
== len(self._ip_adapter_weights)
|
||||
== len(regional_ip_data.scales)
|
||||
== ip_masks.shape[1]
|
||||
)
|
||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||
ipa_weights = self._ip_adapter_weights[ipa_index]
|
||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||
ip_mask = ip_masks[0, ipa_index, ...]
|
||||
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
@ -175,7 +175,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||
assert regional_ip_data is None
|
||||
|
@ -8,8 +8,14 @@ class RegionalIPData:
|
||||
self,
|
||||
image_prompt_embeds: list[torch.Tensor],
|
||||
scales: list[float],
|
||||
masks: list[torch.Tensor],
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Initialize a `IPAdapterConditioningData` object."""
|
||||
assert len(image_prompt_embeds) == len(scales) == len(masks)
|
||||
|
||||
# The image prompt embeddings.
|
||||
# regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
|
||||
# has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
@ -18,3 +24,46 @@ class RegionalIPData:
|
||||
# The scales for the IP-Adapter attention.
|
||||
# scales[i] contains the attention scale for the i'th IP-Adapter.
|
||||
self.scales = scales
|
||||
|
||||
# The IP-Adapter masks.
|
||||
# self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
|
||||
# s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
|
||||
# regions and 0.0 for excluded regions.
|
||||
self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
|
||||
|
||||
def _prepare_masks(
|
||||
self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
|
||||
) -> dict[int, torch.Tensor]:
|
||||
"""Prepare the masks for the IP-Adapter attention."""
|
||||
# Concatenate the masks so that they can be processed more efficiently.
|
||||
mask_tensor = torch.cat(masks, dim=1)
|
||||
|
||||
mask_tensor = mask_tensor.to(device=device, dtype=dtype)
|
||||
|
||||
masks_by_seq_len: dict[int, torch.Tensor] = {}
|
||||
|
||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
b, num_ip_adapters, h, w = mask_tensor.shape
|
||||
# Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
|
||||
assert b == 1
|
||||
|
||||
# The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
|
||||
# the spatial features.
|
||||
query_seq_len = h * w
|
||||
|
||||
masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
|
||||
|
||||
downscale_factor *= 2
|
||||
if downscale_factor <= max_downscale_factor:
|
||||
# We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
|
||||
# regions to be lost entirely.
|
||||
# TODO(ryand): In the future, we may want to experiment with other downsampling methods.
|
||||
mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2)
|
||||
|
||||
return masks_by_seq_len
|
||||
|
||||
def get_masks(self, query_seq_len: int) -> torch.Tensor:
|
||||
"""Get the mask for the given query sequence length."""
|
||||
return self._masks_by_seq_len[query_seq_len]
|
||||
|
@ -286,7 +286,10 @@ class InvokeAIDiffuserComponent:
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
added_cond_kwargs = None
|
||||
@ -404,7 +407,10 @@ class InvokeAIDiffuserComponent:
|
||||
]
|
||||
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||
@ -449,7 +455,10 @@ class InvokeAIDiffuserComponent:
|
||||
]
|
||||
|
||||
scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
|
||||
ip_masks = [ipa.mask for ipa in ip_adapter_data]
|
||||
regional_ip_data = RegionalIPData(
|
||||
image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
|
||||
)
|
||||
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
||||
|
@ -12,10 +12,6 @@ class UNetAttentionPatcher:
|
||||
|
||||
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
|
||||
self._ip_adapters = ip_adapters
|
||||
self._ip_adapter_scales = None
|
||||
|
||||
if self._ip_adapters is not None:
|
||||
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
@ -32,7 +28,6 @@ class UNetAttentionPatcher:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = CustomAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||
self._ip_adapter_scales,
|
||||
)
|
||||
return attn_procs
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user