From 0bdbfd4d1d6d324e380bc0f1be1fd1797f91f362 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 14 Mar 2024 16:58:11 -0400 Subject: [PATCH] Add support for IP-Adapter masks. --- invokeai/app/invocations/ip_adapter.py | 29 +++++++++-- invokeai/app/invocations/latent.py | 12 +++++ .../diffusion/conditioning_data.py | 1 + .../diffusion/custom_atttention.py | 28 +++++------ .../diffusion/regional_ip_data.py | 49 +++++++++++++++++++ .../diffusion/shared_invokeai_diffusion.py | 15 ++++-- .../diffusion/unet_attention_patcher.py | 5 -- 7 files changed, 113 insertions(+), 26 deletions(-) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index eac6f59199..40cde8f3e9 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -1,11 +1,23 @@ from builtins import float -from typing import List, Literal, Union +from typing import List, Literal, Optional, Union from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self -from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import ( + FieldDescriptions, + Input, + InputField, + OutputField, + TensorField, + UIType, +) from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights @@ -30,6 +42,11 @@ class IPAdapterField(BaseModel): end_step_percent: float = Field( default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" ) + mask: Optional[TensorField] = Field( + default=None, + description="The bool mask associated with this IP-Adapter. Excluded regions should be set to False, included " + "regions should be set to True.", + ) @field_validator("weight") @classmethod @@ -52,7 +69,7 @@ class IPAdapterOutput(BaseInvocationOutput): CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.3.0") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -79,6 +96,9 @@ class IPAdapterInvocation(BaseInvocation): end_step_percent: float = InputField( default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)" ) + mask: Optional[TensorField] = InputField( + default=None, description="A mask defining the region that this IP-Adapter applies to." 
+ ) @field_validator("weight") @classmethod @@ -112,6 +132,7 @@ class IPAdapterInvocation(BaseInvocation): weight=self.weight, begin_step_percent=self.begin_step_percent, end_step_percent=self.end_step_percent, + mask=self.mask, ), ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ba668440b8..f1e0431b3a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -633,6 +633,9 @@ class DenoiseLatentsInvocation(BaseInvocation): context: InvocationContext, ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], exit_stack: ExitStack, + latent_height: int, + latent_width: int, + dtype: torch.dtype, ) -> Optional[list[IPAdapterData]]: """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings to the `conditioning_data` (in-place). @@ -670,6 +673,11 @@ class DenoiseLatentsInvocation(BaseInvocation): single_ipa_images, image_encoder_model ) + mask = single_ip_adapter.mask + if mask is not None: + mask = context.tensors.load(mask.tensor_name) + mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) + ip_adapter_data_list.append( IPAdapterData( ip_adapter_model=ip_adapter_model, @@ -677,6 +685,7 @@ class DenoiseLatentsInvocation(BaseInvocation): begin_step_percent=single_ip_adapter.begin_step_percent, end_step_percent=single_ip_adapter.end_step_percent, ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds), + mask=mask, ) ) @@ -916,6 +925,9 @@ class DenoiseLatentsInvocation(BaseInvocation): context=context, ip_adapter=self.ip_adapter, exit_stack=exit_stack, + latent_height=latent_height, + latent_width=latent_width, + dtype=unet.dtype, ) num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler( diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 7196802ed3..9b8ea0968a 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -52,6 +52,7 @@ class IPAdapterConditioningInfo: class IPAdapterData: ip_adapter_model: IPAdapter ip_adapter_conditioning: IPAdapterConditioningInfo + mask: torch.Tensor # Either a single weight applied to all steps, or a list of weights for each step. weight: Union[float, List[float]] = 1.0 diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 34f868306b..c864669df5 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -21,7 +21,6 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): def __init__( self, ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None, - ip_adapter_scales: Optional[list[float]] = None, ): """Initialize a CustomAttnProcessor2_0. Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are @@ -29,17 +28,9 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): Args: ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights for the i'th IP-Adapter. - ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for - the i'th IP-Adapter. 
""" super().__init__() - self._ip_adapter_weights = ip_adapter_weights - self._ip_adapter_scales = ip_adapter_scales - - assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None) - if self._ip_adapter_weights is not None: - assert len(ip_adapter_weights) == len(ip_adapter_scales) def _is_ip_adapter_enabled(self) -> bool: return self._ip_adapter_weights is not None @@ -84,10 +75,10 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # End unmodified block from AttnProcessor2_0. + _, query_seq_len, _ = hidden_states.shape # Handle regional prompt attention masks. if regional_prompt_data is not None and is_cross_attention: assert percent_through is not None - _, query_seq_len, _ = hidden_states.shape prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( query_seq_len=query_seq_len, key_seq_len=sequence_length ) @@ -141,9 +132,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): if is_cross_attention and self._is_ip_adapter_enabled(): if self._is_ip_adapter_enabled(): assert regional_ip_data is not None - for ipa_embed, ipa_weights, scale in zip( - regional_ip_data.image_prompt_embeds, self._ip_adapter_weights, regional_ip_data.scales, strict=True - ): + ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) + assert ( + len(regional_ip_data.image_prompt_embeds) + == len(self._ip_adapter_weights) + == len(regional_ip_data.scales) + == ip_masks.shape[1] + ) + for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds): + ipa_weights = self._ip_adapter_weights[ipa_index] + ipa_scale = regional_ip_data.scales[ipa_index] + ip_mask = ip_masks[0, ipa_index, ...] + # The batch dimensions should match. assert ipa_embed.shape[0] == encoder_hidden_states.shape[0] # The token_len dimensions should match. @@ -175,7 +175,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + scale * ip_hidden_states + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask else: # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. assert regional_ip_data is None diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py index ecf878b416..d3b4505f58 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py @@ -8,8 +8,14 @@ class RegionalIPData: self, image_prompt_embeds: list[torch.Tensor], scales: list[float], + masks: list[torch.Tensor], + dtype: torch.dtype, + device: torch.device, + max_downscale_factor: int = 8, ): """Initialize a `IPAdapterConditioningData` object.""" + assert len(image_prompt_embeds) == len(scales) == len(masks) + # The image prompt embeddings. # regional_ip_data[i] contains the image prompt embeddings for the i'th IP-Adapter. Each tensor # has shape (batch_size, num_ip_images, seq_len, ip_embedding_len). @@ -18,3 +24,46 @@ class RegionalIPData: # The scales for the IP-Adapter attention. # scales[i] contains the attention scale for the i'th IP-Adapter. self.scales = scales + + # The IP-Adapter masks. + # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of + # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included + # regions and 0.0 for excluded regions. 
+ self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype) + + def _prepare_masks( + self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype + ) -> dict[int, torch.Tensor]: + """Prepare the masks for the IP-Adapter attention.""" + # Concatenate the masks so that they can be processed more efficiently. + mask_tensor = torch.cat(masks, dim=1) + + mask_tensor = mask_tensor.to(device=device, dtype=dtype) + + masks_by_seq_len: dict[int, torch.Tensor] = {} + + # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached. + downscale_factor = 1 + while downscale_factor <= max_downscale_factor: + b, num_ip_adapters, h, w = mask_tensor.shape + # Assert that the batch size is 1, because I haven't thought through batch handling for this feature yet. + assert b == 1 + + # The IP-Adapters are applied in the cross-attention layers, where the query sequence length is the h * w of + # the spatial features. + query_seq_len = h * w + + masks_by_seq_len[query_seq_len] = mask_tensor.view((b, num_ip_adapters, -1, 1)) + + downscale_factor *= 2 + if downscale_factor <= max_downscale_factor: + # We use max pooling because we downscale to a pretty low resolution, so we don't want small mask + # regions to be lost entirely. + # TODO(ryand): In the future, we may want to experiment with other downsampling methods. + mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2) + + return masks_by_seq_len + + def get_masks(self, query_seq_len: int) -> torch.Tensor: + """Get the mask for the given query sequence length.""" + return self._masks_by_seq_len[query_seq_len] diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 05b4a6406d..f418133e49 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -286,7 +286,10 @@ class InvokeAIDiffuserComponent: for ipa_conditioning in ip_adapter_conditioning ] scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] - regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales) + ip_masks = [ipa.mask for ipa in ip_adapter_data] + regional_ip_data = RegionalIPData( + image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device + ) cross_attention_kwargs["regional_ip_data"] = regional_ip_data added_cond_kwargs = None @@ -404,7 +407,10 @@ class InvokeAIDiffuserComponent: ] scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] - regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales) + ip_masks = [ipa.mask for ipa in ip_adapter_data] + regional_ip_data = RegionalIPData( + image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device + ) cross_attention_kwargs["regional_ip_data"] = regional_ip_data # Prepare SDXL conditioning kwargs for the unconditioned pass. 
@@ -449,7 +455,10 @@ class InvokeAIDiffuserComponent: ] scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] - regional_ip_data = RegionalIPData(image_prompt_embeds=image_prompt_embeds, scales=scales) + ip_masks = [ipa.mask for ipa in ip_adapter_data] + regional_ip_data = RegionalIPData( + image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device + ) cross_attention_kwargs["regional_ip_data"] = regional_ip_data # Prepare SDXL conditioning kwargs for the conditioned pass. diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index 416430b525..89a203f643 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -12,10 +12,6 @@ class UNetAttentionPatcher: def __init__(self, ip_adapters: Optional[list[IPAdapter]]): self._ip_adapters = ip_adapters - self._ip_adapter_scales = None - - if self._ip_adapters is not None: - self._ip_adapter_scales = [1.0] * len(self._ip_adapters) def _prepare_attention_processors(self, unet: UNet2DConditionModel): """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention @@ -32,7 +28,6 @@ class UNetAttentionPatcher: # Collect the weights from each IP Adapter for the idx'th attention processor. attn_procs[name] = CustomAttnProcessor2_0( [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], - self._ip_adapter_scales, ) return attn_procs
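
Note on the mask plumbing introduced above: the core piece is RegionalIPData._prepare_masks(), which turns the per-adapter spatial masks into a lookup keyed by cross-attention query sequence length so that CustomAttnProcessor2_0 can pick the right resolution at each UNet level. The following standalone sketch is not part of the diff; the name build_mask_pyramid is illustrative, and it assumes each preprocessed mask arrives with shape (1, 1, latent_height, latent_width) with 1.0 inside the region and 0.0 outside, as implied by the torch.cat(masks, dim=1) call in the patch.

    import torch

    def build_mask_pyramid(
        masks: list[torch.Tensor], max_downscale_factor: int = 8
    ) -> dict[int, torch.Tensor]:
        # Assumption: masks[i] is the preprocessed mask for the i'th IP-Adapter,
        # shaped (1, 1, latent_height, latent_width).
        mask_tensor = torch.cat(masks, dim=1)  # -> (1, num_ip_adapters, h, w)
        pyramid: dict[int, torch.Tensor] = {}
        downscale_factor = 1
        while downscale_factor <= max_downscale_factor:
            b, num_ip_adapters, h, w = mask_tensor.shape
            # Cross-attention sees the spatial features flattened, so the query
            # sequence length at this resolution is h * w.
            pyramid[h * w] = mask_tensor.view(b, num_ip_adapters, -1, 1)
            downscale_factor *= 2
            if downscale_factor <= max_downscale_factor:
                # Max pooling (rather than average pooling) keeps small masked
                # regions alive even at the coarsest resolutions.
                mask_tensor = torch.nn.functional.max_pool2d(mask_tensor, kernel_size=2, stride=2)
        return pyramid

    # Example: a 96x96 latent (768x768 image) yields masks keyed by 9216, 2304, 576, 144.
    masks = [torch.zeros(1, 1, 96, 96), torch.ones(1, 1, 96, 96)]
    masks[0][:, :, :, :48] = 1.0  # first IP-Adapter only affects the left half
    pyramid = build_mask_pyramid(masks)
    print(sorted(pyramid.keys()))  # [144, 576, 2304, 9216]

The downscale factors 1 through 8 are meant to span the resolutions at which the SD UNet runs cross-attention, which is presumably why max_downscale_factor defaults to 8. In the patched CustomAttnProcessor2_0, get_masks(query_seq_len) then selects the matching level and the adapter output is blended as hidden_states + ipa_scale * ip_hidden_states * ip_mask, so each IP-Adapter only contributes inside its own region.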