Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Add support for IP-Adapter masks.
parent 2e27ed5f3d
commit 0bdbfd4d1d

@@ -1,11 +1,23 @@
 from builtins import float
-from typing import List, Literal, Union
+from typing import List, Literal, Optional, Union
 
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.fields import (
+    FieldDescriptions,
+    Input,
+    InputField,
+    OutputField,
+    TensorField,
+    UIType,
+)
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
@@ -30,6 +42,11 @@ class IPAdapterField(BaseModel):
     end_step_percent: float = Field(
         default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
     )
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included "
+        "regions should be set to True.",
+    )
 
     @field_validator("weight")
     @classmethod
@@ -52,7 +69,7 @@ class IPAdapterOutput(BaseInvocationOutput):
 CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
 
 
-@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0")
 class IPAdapterInvocation(BaseInvocation):
     """Collects IP-Adapter info to pass to other nodes."""
 
@@ -79,6 +96,9 @@ class IPAdapterInvocation(BaseInvocation):
     end_step_percent: float = InputField(
         default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
     )
+    mask: Optional[TensorField] = InputField(
+        default=None, description="A mask defining the region that this IP-Adapter applies to."
+    )
 
     @field_validator("weight")
     @classmethod
@@ -112,6 +132,7 @@ class IPAdapterInvocation(BaseInvocation):
                 weight=self.weight,
                 begin_step_percent=self.begin_step_percent,
                 end_step_percent=self.end_step_percent,
+                mask=self.mask,
             ),
         )
 
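Illustrative sketch (not from the commit): the new `mask` inputs above describe a boolean tensor in which True marks the region the IP-Adapter should influence and False marks excluded regions. A minimal example of building such a mask, assuming a 512x512 image and a made-up left-half region:

    import torch

    # Hypothetical example: restrict an IP-Adapter to the left half of a 512x512 image.
    height, width = 512, 512
    mask = torch.zeros((height, width), dtype=torch.bool)
    mask[:, : width // 2] = True  # True = included region, False = excluded region
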
@@ -633,6 +633,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
         exit_stack: ExitStack,
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
     ) -> Optional[list[IPAdapterData]]:
         """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
         to the `conditioning_data` (in-place).
@@ -670,6 +673,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 single_ipa_images, image_encoder_model
             )
 
+            mask = single_ip_adapter.mask
+            if mask is not None:
+                mask = context.tensors.load(mask.tensor_name)
+            mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
+
             ip_adapter_data_list.append(
                 IPAdapterData(
                     ip_adapter_model=ip_adapter_model,
@@ -677,6 +685,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     begin_step_percent=single_ip_adapter.begin_step_percent,
                     end_step_percent=single_ip_adapter.end_step_percent,
                     ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
+                    mask=mask,
                 )
             )
 
@@ -916,6 +925,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 context=context,
                 ip_adapter=self.ip_adapter,
                 exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                dtype=unet.dtype,
             )
 
             num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
@@ -52,6 +52,7 @@ class IPAdapterConditioningInfo:
 class IPAdapterData:
     ip_adapter_model: IPAdapter
     ip_adapter_conditioning: IPAdapterConditioningInfo
+    mask: torch.Tensor
 
     # Either a single weight applied to all steps, or a list of weights for each step.
     weight: Union[float, List[float]] = 1.0
@@ -21,7 +21,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
     def __init__(
         self,
         ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
-        ip_adapter_scales: Optional[list[float]] = None,
     ):
         """Initialize a CustomAttnProcessor2_0.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -29,17 +28,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         Args:
             ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
                 for the i'th IP-Adapter.
-            ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
-                the i'th IP-Adapter.
         """
         super().__init__()
-
         self._ip_adapter_weights = ip_adapter_weights
-        self._ip_adapter_scales = ip_adapter_scales
-
-        assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
-        if self._ip_adapter_weights is not None:
-            assert len(ip_adapter_weights) == len(ip_adapter_scales)
 
     def _is_ip_adapter_enabled(self) -> bool:
         return self._ip_adapter_weights is not None
@@ -84,10 +75,10 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
         # End unmodified block from AttnProcessor2_0.
 
+        _, query_seq_len, _ = hidden_states.shape
         # Handle regional prompt attention masks.
         if regional_prompt_data is not None and is_cross_attention:
             assert percent_through is not None
-            _, query_seq_len, _ = hidden_states.shape
             prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
                 query_seq_len=query_seq_len, key_seq_len=sequence_length
             )
@@ -141,9 +132,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         if is_cross_attention and self._is_ip_adapter_enabled():
             if self._is_ip_adapter_enabled():
                 assert regional_ip_data is not None
-                for ipa_embed, ipa_weights, scale in zip(
-                    regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True
-                ):
+                ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
+                assert (
+                    len(regional_ip_data.image_prompt_embeds)
+                    == len(self._ip_adapter_weights)
+                    == len(regional_ip_data.scales)
+                    == ip_masks.shape[1]
+                )
+                for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
+                    ipa_weights = self._ip_adapter_weights[ipa_index]
+                    ipa_scale = regional_ip_data.scales[ipa_index]
+                    ip_mask = ip_masks[0, ipa_index, ...]
+
                     # The batch dimensions should match.
                     assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                     # The token_len dimensions should match.
@@ -175,7 +175,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                     # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
 
-                    hidden_states = hidden_states + scale * ip_hidden_states
+                    hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
             else:
                 # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
                 assert regional_ip_data is None
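Illustrative sketch (not from the commit): in the per-adapter loop above, `ip_mask` has shape (query_seq_len, 1), so it broadcasts over the feature dimension and gates where each IP-Adapter's contribution is added. A standalone toy example with made-up shapes and values:

    import torch

    batch, query_seq_len, dim = 1, 4, 8
    hidden_states = torch.zeros(batch, query_seq_len, dim)
    ip_hidden_states = torch.ones(batch, query_seq_len, dim)
    ip_mask = torch.tensor([[1.0], [1.0], [0.0], [0.0]])  # (query_seq_len, 1): first two positions included
    ipa_scale = 0.75

    # Same blend as in the diff: the mask zeroes the IP-Adapter contribution outside its region.
    hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
    # Positions 0-1 now hold 0.75 everywhere; positions 2-3 remain 0.0.
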
@@ -8,8 +8,14 @@ class RegionalIPData:
         self,
         image_prompt_embeds: list[torch.Tensor],
         scales: list[float],
+        masks: list[torch.Tensor],
+        dtype: torch.dtype,
+        device: torch.device,
+        max_downscale_factor: int = 8,
     ):
         """Initialize a `IPAdapterConditioningData` object."""
+        assert len(image_prompt_embeds) == len(scales) == len(masks)
+
         # The image prompt embeddings.
         # regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor
         # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
@@ -18,3 +24,46 @@ class RegionalIPData:
         # The scales for the IP-Adapter attention.
         # scales[i] contains the attention scale for the i'th IP-Adapter.
         self.scales = scales
+
+        # The IP-Adapter masks.
+        # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of
+        # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included
+        # regions and 0.0 for excluded regions.
+        self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype)
+
+    def _prepare_masks(
+        self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype
+    ) -> dict[int, torch.Tensor]:
+        """Prepare the masks for the IP-Adapter attention."""
+        # Concatenate the masks so that they can be processed more efficiently.
+        mask_tensor = torch.cat(masks, dim=1)
+
+        mask_tensor = mask_tensor.to(device=device, dtype=dtype)
+
+        masks_by_seq_len: dict[int, torch.Tensor] = {}
+
+        # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
+        downscale_factor = 1
+        while downscale_factor <= max_downscale_factor:
+            b, num_ip_adapters, h, w = mask_tensor.shape
+            # Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet.
+            assert b == 1
+
+            # The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of
+            # the spatial features.
+            query_seq_len = h * w
+
+            masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1))
+
+            downscale_factor *= 2
+            if downscale_factor <= max_downscale_factor:
+                # We use max pooling because we downscale to a pretty low resolution, so we don't want small mask
+                # regions to be lost entirely.
+                # TODO(ryand): In the future, we may want to experiment with other downsampling methods.
+                mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2)
+
+        return masks_by_seq_len
+
+    def get_masks(self, query_seq_len: int) -> torch.Tensor:
+        """Get the mask for the given query sequence length."""
+        return self._masks_by_seq_len[query_seq_len]
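Illustrative sketch (not from the commit): a minimal standalone example in the spirit of `_prepare_masks` above, using a toy 8x8 mask to show how each pooling level is keyed by its flattened spatial size, which is the query sequence length at that level:

    import torch

    mask = torch.zeros((1, 1, 8, 8))  # (batch, num_ip_adapters, H, W)
    mask[:, :, :4, :] = 1.0  # top half included

    masks_by_seq_len = {}
    downscale_factor = 1
    max_downscale_factor = 8
    while downscale_factor <= max_downscale_factor:
        b, n, h, w = mask.shape
        masks_by_seq_len[h * w] = mask.view(b, n, -1, 1)
        downscale_factor *= 2
        if downscale_factor <= max_downscale_factor:
            # Max pooling keeps small masked regions from vanishing at low resolutions.
            mask = torch.nn.functional.max_pool2d(mask, kernel_size=2, stride=2)

    print(sorted(masks_by_seq_len.keys()))  # [1, 4, 16, 64]
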
@@ -286,7 +286,10 @@ class InvokeAIDiffuserComponent:
                 for ipa_conditioning in ip_adapter_conditioning
             ]
             scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
-            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            ip_masks = [ipa.mask for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(
+                image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
+            )
             cross_attention_kwargs["regional_ip_data"] = regional_ip_data
 
         added_cond_kwargs = None
@@ -404,7 +407,10 @@ class InvokeAIDiffuserComponent:
             ]
 
             scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
-            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            ip_masks = [ipa.mask for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(
+                image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
+            )
             cross_attention_kwargs["regional_ip_data"] = regional_ip_data
 
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
@@ -449,7 +455,10 @@ class InvokeAIDiffuserComponent:
             ]
 
             scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data]
-            regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales)
+            ip_masks = [ipa.mask for ipa in ip_adapter_data]
+            regional_ip_data = RegionalIPData(
+                image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device
+            )
             cross_attention_kwargs["regional_ip_data"] = regional_ip_data
 
         # Prepare SDXL conditioning kwargs for the conditioned pass.
@@ -12,10 +12,6 @@ class UNetAttentionPatcher:
 
     def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
         self._ip_adapters = ip_adapters
-        self._ip_adapter_scales = None
-
-        if self._ip_adapters is not None:
-            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
 
     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
@@ -32,7 +28,6 @@ class UNetAttentionPatcher:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
                 attn_procs[name] = CustomAttnProcessor2_0(
                     [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                    self._ip_adapter_scales,
                 )
         return attn_procs