Compare commits

...

52 Commits

Author SHA1 Message Date
ff950bc5cd Add support for mask weights, and only mask the tokens associated with the prompts (not the entire 77-token embedding). 2024-03-07 14:30:51 -05:00
969982b789 Fixup some details of densediffusion for testing. 2024-03-06 19:03:26 -05:00
b8cbff828b wip 2024-03-06 10:52:35 -05:00
d3a40c5b2b Rough hacky implementation of DenseDiffusion. 2024-03-05 18:10:01 -05:00
57266d36a2 Remove dispatch_progress() function that was added accidentally during conflict resolution. 2024-03-05 15:31:54 -05:00
41e1a9f202 Use the correct device / dtype for RegionalPromptData calculations. 2024-03-05 15:19:58 -05:00
bcfb43e5f0 (minor) Remove commented code. 2024-03-05 09:12:17 -05:00
a665f20fb5 Add positive_self_attn_mask_score and self_attn_adjustment_end_step_percent to the prompt nodes. 2024-03-04 15:34:26 -05:00
d313e5eb70 Remove AddConditioningMaskInvocation. 2024-03-04 14:11:38 -05:00
271f8f2414 Merge branch 'main' into ryan/regional-conditioning-tuning 2024-03-04 10:52:24 -05:00
5fad379192 Add ability to control regional prompt region weights. 2024-03-03 12:55:07 -05:00
ad18429fe3 Experimentation with various regional prompting tuning params. 2024-03-02 17:43:21 -05:00
942efa011e Implement (very slow) self-attention regional masking. 2024-03-01 18:43:32 -05:00
ffc4ebb14c Merge branch 'ryan/remove-attention-map-saving' into ryan/regional-conditioning 2024-03-01 11:33:50 -05:00
5b3adf0740 Remove unused code for attention map saving. 2024-02-29 23:42:37 -05:00
a5c94fba43 Delete unused functions from shared_invokeai_diffusion.py. 2024-02-29 23:41:15 -05:00
3e14bd6c45 Remove unused constructor declared with typo in name: __int__. 2024-02-29 22:42:59 -05:00
8721926f14 Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication. 2024-02-29 22:42:59 -05:00
d87ff3a206 Remove outdated comments related to T2I-Adapters and ControlNets. 2024-02-29 22:42:59 -05:00
7d9671014b Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported. 2024-02-29 22:42:59 -05:00
4a1acd4db9 Fix: avoid storing extra conditioning info in two places. 2024-02-29 22:42:59 -05:00
8989a6cdc6 Get multi-prompt attention working simultaneously with IP-adapter. 2024-02-29 14:54:13 -05:00
f44d3da9b1 Add CustomAttnProcessor2_0 class with simultaneous support for IP-Adapter and regional prompting. 2024-02-29 12:48:55 -05:00
1bbd4f751d Fixup logic around compatibility of prompt-to-prompt, IP-Adapter, regional prompting. 2024-02-29 12:47:23 -05:00
bdf3691ad0 Improve the logic for selecting SDXL pooled embeds when handling multi-region prompts. 2024-02-28 22:14:41 -05:00
e7f7ae660d Raise a clear error message if prompt-to-prompt cross-attention control is triggered when using multiple prompts. 2024-02-28 21:38:25 -05:00
e132afb705 Make regional prompting work with sequential conditioning. 2024-02-28 21:21:50 -05:00
5f49e7ae26 Move regional prompt concatenation further up the stack. This solves a number of issues. 2024-02-28 20:11:47 -05:00
53ebca58ff Rename ConditioningData to TextConditioningData. 2024-02-28 13:53:56 -05:00
ee1b3157ce Split ip_adapter_conditioning out from ConditioningData. 2024-02-28 13:49:02 -05:00
e7ec13f209 Remove scheduler_args from ConditioningData structure. 2024-02-28 12:15:39 -05:00
cad3e5dbd7 Remove dead code related to an old symmetry feature. 2024-02-28 11:29:52 -05:00
845c4e93ae Update various comments related to regional prompting, and delete duplicate _preprocess_regional_prompt_mask(...) function. 2024-02-28 10:20:22 -05:00
54971afe44 Add symmetric support for regional negative text prompts. 2024-02-27 20:05:02 -05:00
cfba51aed5 Removed unused function: _prepare_text_embeddings(...) 2024-02-27 19:23:20 -05:00
2966c8de2c Handle conditioned and unconditioned text conditioning in the same way for regional prompt attention. 2024-02-27 18:16:01 -05:00
b0fcbe552e Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation. 2024-02-26 17:34:37 -05:00
d132fb4818 Get RegionalPromptAttnProcessor2_0 working with a ton of hacks. 2024-02-17 19:56:37 -05:00
2d5d370f38 Route masks into the RegionalPromptAttnProcessor2_0 processors. 2024-02-16 19:35:24 -05:00
878bbc3527 Add RectangleMaskInvocation. 2024-02-16 18:03:02 -05:00
caa690e24d Add concatenation of multiple text conditioning tensors, and patching of RegionalPromptAttnProcessor2_0 into the UNet. 2024-02-16 17:09:06 -05:00
38248b988f Fix a minor bug in the logic of the IPAttnProcessor2_0. The change won't have any functional effect, since this attention implementation was only being used for cross-attention, but the logic should be correct now in case we wanted to use it for self-attention. 2024-02-16 09:10:47 -05:00
ba4788007f Initialize a RegionalPromptAttnProcessor2_0 class by copying AttnProcessor2_0 from diffusers. 2024-02-15 17:52:44 -05:00
ef51005881 Remove unused code for attention map saving. 2024-02-15 17:28:55 -05:00
7b0326d7f7 Delete unused functions from shared_invokeai_diffusion.py. 2024-02-15 17:22:37 -05:00
f590b39f88 Add support for a list of ConditioningFields in DenoiseLatents. 2024-02-15 14:41:54 -05:00
58277c6ada Add a mask to the ConditioningField primitive type. 2024-02-15 13:53:32 -05:00
382fa57f3b Remove unused constructor declared with typo in name: __int__. 2024-02-14 18:18:58 -05:00
ee3abc171d Merge sequential conditioning and cac conditioning logic to eliminate a bunch of duplication. 2024-02-14 18:17:46 -05:00
bf72cee555 Remove outdated comments related to T2I-Adapters and ControlNets. 2024-02-14 17:37:40 -05:00
e866e3b19f Remove use of **kwargs in do_unet_step(...), where full parameter list is known and supported. 2024-02-14 17:37:32 -05:00
16e574825c Fix: avoid storing extra conditioning info in two places. 2024-02-14 15:34:15 -05:00
13 changed files with 955 additions and 492 deletions

View File

@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
MaskField,
OutputField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list from invokeai.app.util.ti_utils import generate_ti_list
@ -36,7 +44,7 @@ from .model import ClipField
title="Prompt", title="Prompt",
tags=["prompt", "compel"], tags=["prompt", "compel"],
category="conditioning", category="conditioning",
version="1.0.1", version="1.2.0",
) )
class CompelInvocation(BaseInvocation): class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
) )
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
mask_weight=self.mask_weight,
)
)
class SDXLPromptInvocationBase: class SDXLPromptInvocationBase:
@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
title="SDXL Prompt", title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"], tags=["sdxl", "compel", "prompt"],
category="conditioning", category="conditioning",
version="1.0.1", version="1.2.0",
) )
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning.""" """Parse prompt using compel package to conditioning."""
@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
mask_weight: float = InputField(default=1.0, description="")
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
c1, c1_pooled, ec1 = self.run_clip_compel( c1, c1_pooled, ec1 = self.run_clip_compel(
@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
mask_weight=self.mask_weight,
)
)
@invocation( @invocation(
@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput.build(conditioning_name) return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))
@invocation_output("clip_skip_output") @invocation_output("clip_skip_output")

View File

@ -0,0 +1,40 @@
import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
InvocationContext,
invocation,
)
from invokeai.app.invocations.fields import InputField, WithMetadata
from invokeai.app.invocations.primitives import MaskField, MaskOutput
@invocation(
"rectangle_mask",
title="Create Rectangle Mask",
tags=["conditioning"],
category="conditioning",
version="1.0.0",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"""Create a rectangular mask."""
height: int = InputField(description="The height of the entire mask.")
width: int = InputField(description="The width of the entire mask.")
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
mask[
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
] = True
mask_name = context.tensors.save(mask)
return MaskOutput(
mask=MaskField(mask_name=mask_name),
width=self.width,
height=self.height,
)
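
As a standalone illustration (not part of the diff), this is the tensor the node above produces for a small toy input, using plain PyTorch:

import torch

# Same construction as RectangleMaskInvocation.invoke(), with toy dimensions:
# a 4x6 canvas containing a 2x3 rectangle whose top-left corner is at (y=1, x=2).
height, width = 4, 6
y_top, x_left = 1, 2
rectangle_height, rectangle_width = 2, 3

mask = torch.zeros((1, height, width), dtype=torch.bool)
mask[:, y_top : y_top + rectangle_height, x_left : x_left + rectangle_width] = True

# mask[0] is True exactly inside the rectangle:
# [[F, F, F, F, F, F],
#  [F, F, T, T, T, F],
#  [F, F, T, T, T, F],
#  [F, F, F, F, F, F]]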

View File

@ -194,6 +194,12 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board") board_id: str = Field(description="The id of the board")
class MaskField(BaseModel):
"""A mask primitive field."""
mask_name: str = Field(description="The name of the mask.")
class DenoiseMaskField(BaseModel): class DenoiseMaskField(BaseModel):
"""An inpaint mask field""" """An inpaint mask field"""
@ -225,7 +231,12 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value""" """A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor") conditioning_name: str = Field(description="The name of conditioning tensor")
# endregion mask: Optional[MaskField] = Field(
default=None,
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)
mask_weight: float = Field(description="")
class MetadataField(RootModel): class MetadataField(RootModel):
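
For concreteness, a minimal sketch (not part of the diff) of how these new fields fit together; the mask and conditioning names are hypothetical:

from invokeai.app.invocations.fields import ConditioningField, MaskField

# Hypothetical names, for illustration only; MaskField and ConditioningField as defined above.
region_mask = MaskField(mask_name="rect_mask_1")
cond = ConditioningField(
    conditioning_name="prompt_embeds_1",
    mask=region_mask,   # True inside the region this prompt applies to, False elsewhere
    mask_weight=1.0,
)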

View File

@ -1,5 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import math import math
from contextlib import ExitStack from contextlib import ExitStack
from functools import singledispatchmethod from functools import singledispatchmethod
@ -9,6 +9,7 @@ import einops
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import torch import torch
import torchvision
import torchvision.transforms as T import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.configuration_utils import ConfigMixin from diffusers.configuration_utils import ConfigMixin
@ -55,7 +56,14 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
@ -65,7 +73,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
T2IAdapterData, T2IAdapterData,
image_resized_to_grid_as_tensor, image_resized_to_grid_as_tensor,
) )
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import ( from .baseinvocation import (
@ -284,11 +291,11 @@ def get_scheduler(
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
positive_conditioning: ConditioningField = InputField( positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0 description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
) )
negative_conditioning: ConditioningField = InputField( negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1 description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
) )
noise: Optional[LatentsField] = InputField( noise: Optional[LatentsField] = InputField(
default=None, default=None,
@ -365,39 +372,190 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
return v return v
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.mask_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
def concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
conditioning_fields: list[ConditioningField],
latent_height: int,
latent_width: int,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
extra_conditioning = None
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx]
if (
text_embedding_info.extra_conditioning is not None
and text_embedding_info.extra_conditioning.wants_cross_attention_control
):
extra_conditioning = text_embedding_info.extra_conditioning
if is_sdxl:
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
# global prompt information. In an ideal case, there should be exactly one global prompt without a
# mask, but we don't enforce this.
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
if pooled_embedding is None or mask is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None or mask is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
# embedding_ranges.append(
# Range(
# start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
# )
# )
# HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
# in the count.
embedding_ranges.append(
Range(
start=cur_text_embedding_len + 1,
end=cur_text_embedding_len
+ text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
)
)
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
mask_weights=[x.mask_weight for x in conditioning_fields],
)
if extra_conditioning is not None and len(text_conditionings) > 1:
raise ValueError(
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
"prompts."
)
if is_sdxl:
return SDXLConditioningInfo(
embeds=text_embedding,
extra_conditioning=extra_conditioning,
pooled_embeds=pooled_embedding,
add_time_ids=add_time_ids,
), regions
return BasicConditioningInfo(
embeds=text_embedding,
extra_conditioning=extra_conditioning,
), regions
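
To make the range bookkeeping above concrete, a small standalone sketch (not part of the diff) of how embedding_ranges line up when two 77-position Compel embeddings are concatenated; the token counts are made-up numbers:

# Illustration only. Each Compel embedding block spans 77 positions; the reported
# token counts (tokens_count_including_eos_bos) are hypothetical.
prompt_token_counts = [8, 12]
block_lengths = [77, 77]

cur = 0
ranges = []
for count, length in zip(prompt_token_counts, block_lengths):
    # Mirrors the loop above: skip the BOS position, cover `count` positions of the block.
    ranges.append((cur + 1, cur + count))
    cur += length

print(ranges)  # [(1, 8), (78, 89)]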
def get_conditioning_data( def get_conditioning_data(
self, self,
context: InvocationContext, context: InvocationContext,
scheduler: Scheduler,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
seed: int, latent_height: int,
) -> ConditioningData: latent_width: int,
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name) ) -> TextConditioningData:
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) # Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name) cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) cond_list, context, unet.device, unet.dtype
conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
text_embeddings=c,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
postprocessing_settings=PostprocessingSettings(
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None, # v_symmetry_time_pct,
),
) )
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
scheduler, uncond_list, context, unet.device, unet.dtype
# for ddim scheduler )
eta=0.0, # ddim_eta cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
# for ancestral and sde schedulers text_conditionings=cond_text_embeddings,
# flip all bits to have noise different from initial masks=cond_text_embedding_masks,
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF), conditioning_fields=cond_list,
latent_height=latent_height,
latent_width=latent_width,
)
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
conditioning_fields=uncond_list,
latent_height=latent_height,
latent_width=latent_width,
)
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
) )
return conditioning_data return conditioning_data
@ -503,7 +661,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
self, self,
context: InvocationContext, context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]], ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
conditioning_data: ConditioningData,
exit_stack: ExitStack, exit_stack: ExitStack,
) -> Optional[list[IPAdapterData]]: ) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
@ -520,7 +677,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
return None return None
ip_adapter_data_list = [] ip_adapter_data_list = []
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter: for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context( ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(key=single_ip_adapter.ip_adapter_model.key) context.models.load(key=single_ip_adapter.ip_adapter_model.key)
@ -543,16 +699,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
single_ipa_images, image_encoder_model single_ipa_images, image_encoder_model
) )
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)
ip_adapter_data_list.append( ip_adapter_data_list.append(
IPAdapterData( IPAdapterData(
ip_adapter_model=ip_adapter_model, ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight, weight=single_ip_adapter.weight,
begin_step_percent=single_ip_adapter.begin_step_percent, begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent, end_step_percent=single_ip_adapter.end_step_percent,
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
) )
) )
@ -642,6 +795,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
steps: int, steps: int,
denoising_start: float, denoising_start: float,
denoising_end: float, denoising_end: float,
seed: int,
) -> Tuple[int, List[int], int]: ) -> Tuple[int, List[int], int]:
assert isinstance(scheduler, ConfigMixin) assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False): if scheduler.config.get("cpu_only", False):
@ -670,7 +824,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx] timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order num_inference_steps = len(timesteps) // scheduler.order
return num_inference_steps, timesteps, init_timestep scheduler_step_kwargs = {}
scheduler_step_signature = inspect.signature(scheduler.step)
if "generator" in scheduler_step_signature.parameters:
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
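
As a standalone illustration (not part of the diff) of the signature check and flipped-seed convention described above; FakeScheduler is a made-up stand-in, not a diffusers class:

import inspect
import torch

class FakeScheduler:
    # Stand-in with the same step() shape as a diffusers scheduler that accepts a generator.
    def step(self, model_output, timestep, sample, generator=None):
        ...

seed = 1234
scheduler_step_kwargs = {}
if "generator" in inspect.signature(FakeScheduler.step).parameters:
    # Flip all bits of the seed, matching the convention noted in the comment above.
    scheduler_step_kwargs["generator"] = torch.Generator(device="cpu").manual_seed(seed ^ 0xFFFFFFFF)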
def prep_inpaint_mask( def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor self, context: InvocationContext, latents: torch.Tensor
@ -763,7 +925,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) _, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
)
controlnet_data = self.prep_control_data( controlnet_data = self.prep_control_data(
context=context, context=context,
@ -777,16 +942,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data( ip_adapter_data = self.prep_ip_adapter_data(
context=context, context=context,
ip_adapter=self.ip_adapter, ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
exit_stack=exit_stack, exit_stack=exit_stack,
) )
num_inference_steps, timesteps, init_timestep = self.init_scheduler( num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler, scheduler,
device=unet.device, device=unet.device,
steps=self.steps, steps=self.steps,
denoising_start=self.denoising_start, denoising_start=self.denoising_start,
denoising_end=self.denoising_end, denoising_end=self.denoising_end,
seed=seed,
) )
result_latents = pipeline.latents_from_embeddings( result_latents = pipeline.latents_from_embeddings(
@ -799,6 +964,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents, masked_latents=masked_latents,
gradient_mask=gradient_mask, gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, control_data=controlnet_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,

View File

@ -14,6 +14,7 @@ from invokeai.app.invocations.fields import (
Input, Input,
InputField, InputField,
LatentsField, LatentsField,
MaskField,
OutputField, OutputField,
UIComponent, UIComponent,
) )
@ -229,6 +230,18 @@ class StringCollectionInvocation(BaseInvocation):
# region Image # region Image
@invocation_output("mask_output")
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor.
dtype: torch.bool
shape: (1, height, width).
"""
mask: MaskField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("image_output") @invocation_output("image_output")
class ImageOutput(BaseInvocationOutput): class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image""" """Base class for nodes that output a single image"""
@ -414,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "ConditioningOutput":
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_collection_output") @invocation_output("conditioning_collection_output")
class ConditioningCollectionOutput(BaseInvocationOutput): class ConditioningCollectionOutput(BaseInvocationOutput):

View File

@ -1,182 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
assert len(weights) == len(scales)
self._weights = weights
self._scales = scales
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.
Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states

View File

@ -23,9 +23,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData IPAdapterConditioningInfo,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from ..util import auto_detect_slice_size, normalize_device from ..util import auto_detect_slice_size, normalize_device
@ -170,10 +173,11 @@ class ControlNetData:
@dataclass @dataclass
class IPAdapterData: class IPAdapterData:
ip_adapter_model: IPAdapter = Field(default=None) ip_adapter_model: IPAdapter
# TODO: change to polymorphic so can do different weights per step (once implemented...) ip_adapter_conditioning: IPAdapterConditioningInfo
# Either a single weight applied to all steps, or a list of weights for each step.
weight: Union[float, List[float]] = Field(default=1.0) weight: Union[float, List[float]] = Field(default=1.0)
# weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0) begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
@ -314,7 +318,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self, self,
latents: torch.Tensor, latents: torch.Tensor,
num_inference_steps: int, num_inference_steps: int,
conditioning_data: ConditioningData, scheduler_step_kwargs: dict[str, Any],
conditioning_data: TextConditioningData,
*, *,
noise: Optional[torch.Tensor], noise: Optional[torch.Tensor],
timesteps: torch.Tensor, timesteps: torch.Tensor,
@ -374,6 +379,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents, latents,
timesteps, timesteps,
conditioning_data, conditioning_data,
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,
@ -393,7 +399,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self, self,
latents: torch.Tensor, latents: torch.Tensor,
timesteps, timesteps,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any],
*, *,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
@ -410,22 +417,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0: if timesteps.shape[0] == 0:
return latents return latents
ip_adapter_unet_patcher = None extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning use_cross_attention_control = (
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
)
use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
)
if use_cross_attention_control and use_ip_adapter:
raise ValueError(
"Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
)
if use_cross_attention_control and use_regional_prompting:
raise ValueError(
"Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
)
unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter
attn_ctx = nullcontext()
if use_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context( attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model, self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info, extra_conditioning_info=extra_conditioning_info,
) )
self.use_ip_adapter = False if use_ip_adapter or use_regional_prompting:
elif ip_adapter_data is not None: ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
# As it is now, the IP-Adapter will silently be skipped. attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
with attn_ctx: with attn_ctx:
if callback is not None: if callback is not None:
@ -448,22 +468,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data, conditioning_data,
step_index=i, step_index=i,
total_step_count=len(timesteps), total_step_count=len(timesteps),
scheduler_step_kwargs=scheduler_step_kwargs,
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data, t2i_adapter_data=t2i_adapter_data,
ip_adapter_unet_patcher=ip_adapter_unet_patcher, unet_attention_patcher=unet_attention_patcher,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
latents = self.invokeai_diffuser.do_latent_postprocessing(
postprocessing_settings=conditioning_data.postprocessing_settings,
latents=latents,
sigma=batched_t,
step_index=i,
total_step_count=len(timesteps),
)
predicted_original = getattr(step_output, "pred_original_sample", None) predicted_original = getattr(step_output, "pred_original_sample", None)
if callback is not None: if callback is not None:
@ -485,14 +497,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self, self,
t: torch.Tensor, t: torch.Tensor,
latents: torch.Tensor, latents: torch.Tensor,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None, ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None, t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_unet_patcher: Optional[UNetPatcher] = None, unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
@ -515,10 +528,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
) )
if step_index >= first_adapter_step and step_index <= last_adapter_step: if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range. # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
ip_adapter_unet_patcher.set_scale(i, weight) unet_attention_patcher.set_scale(i, weight)
else: else:
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect. # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0) unet_attention_patcher.set_scale(i, 0.0)
# Handle ControlNet(s) # Handle ControlNet(s)
down_block_additional_residuals = None down_block_additional_residuals = None
@ -562,12 +575,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
down_intrablock_additional_residuals = accum_adapter_state down_intrablock_additional_residuals = accum_adapter_state
ip_adapter_conditioning = None
if ip_adapter_data is not None:
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input, sample=latent_model_input,
timestep=t, # TODO: debug how handled batched and non batched timesteps timestep=t, # TODO: debug how handled batched and non batched timesteps
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
@ -587,7 +605,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args) step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
# TODO: issue to diffusers? # TODO: issue to diffusers?
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call # undo internal counter increment done by scheduler.step, so timestep can be resolved as before call

View File

@ -1,7 +1,5 @@
import dataclasses from dataclasses import dataclass
import inspect from typing import List, Optional, Union
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import torch import torch
@ -10,6 +8,11 @@ from .cross_attention_control import Arguments
@dataclass @dataclass
class ExtraConditioningInfo: class ExtraConditioningInfo:
"""Extra conditioning information produced by Compel.
This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
"""
tokens_count_including_eos_bos: int tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None cross_attention_control_args: Optional[Arguments] = None
@ -20,6 +23,8 @@ class ExtraConditioningInfo:
@dataclass @dataclass
class BasicConditioningInfo: class BasicConditioningInfo:
"""SD 1/2 text conditioning information produced by Compel."""
embeds: torch.Tensor embeds: torch.Tensor
extra_conditioning: Optional[ExtraConditioningInfo] extra_conditioning: Optional[ExtraConditioningInfo]
@ -35,6 +40,8 @@ class ConditioningFieldData:
@dataclass @dataclass
class SDXLConditioningInfo(BasicConditioningInfo): class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
pooled_embeds: torch.Tensor pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor add_time_ids: torch.Tensor
@ -44,14 +51,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype) return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
@dataclass @dataclass
class IPAdapterConditioningInfo: class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor cond_image_prompt_embeds: torch.Tensor
@ -65,41 +64,55 @@ class IPAdapterConditioningInfo:
@dataclass @dataclass
class ConditioningData: class Range:
unconditioned_embeddings: BasicConditioningInfo start: int
text_embeddings: BasicConditioningInfo end: int
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
@property class TextConditioningRegions:
def dtype(self): def __init__(
return self.text_embeddings.dtype self,
masks: torch.Tensor,
ranges: list[Range],
mask_weights: list[float],
):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
# Dtype: torch.bool
self.masks = masks
def add_scheduler_args_if_applicable(self, scheduler, **kwargs): # A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
scheduler_args = dict(self.scheduler_args) # ranges[i] contains the embedding range for the i'th prompt / mask.
step_method = inspect.signature(scheduler.step) self.ranges = ranges
for name, value in kwargs.items():
try: self.mask_weights = mask_weights
step_method.bind_partial(**{name: value})
except TypeError: assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else: class TextConditioningData:
scheduler_args[name] = value def __init__(
return dataclasses.replace(self, scheduler_args=scheduler_args) self,
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
):
self.uncond_text = uncond_text
self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)

View File

@ -0,0 +1,242 @@
import math
from typing import Optional
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
class CustomAttnProcessor2_0(AttnProcessor2_0):
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
This implementation is based on
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
Supported custom features:
- IP-Adapter
- Regional prompt attention
"""
def __init__(
self,
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
ip_adapter_scales: Optional[list[float]] = None,
):
"""Initialize a CustomAttnProcessor2_0.
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
layer-specific are passed to __init__().
Args:
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
for the i'th IP-Adapter.
ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
the i'th IP-Adapter.
"""
super().__init__()
self._ip_adapter_weights = ip_adapter_weights
self._ip_adapter_scales = ip_adapter_scales
assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
if self._ip_adapter_weights is not None:
assert len(ip_adapter_weights) == len(ip_adapter_scales)
def _is_ip_adapter_enabled(self) -> bool:
return self._ip_adapter_weights is not None
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
# For regional prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[float] = None,
# For IP-Adapter:
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.FloatTensor:
"""Apply attention.
Args:
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
apply regional prompt masking.
ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
# If True, we are doing cross-attention; if False, we are doing self-attention.
is_cross_attention = encoder_hidden_states is not None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
# Handle regional prompt attention masks.
if regional_prompt_data is not None:
assert percent_through is not None
_, query_seq_len, _ = hidden_states.shape
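# In both branches below, the regional mask weight decays to zero over the denoising process via
# (1 - percent_through) ** 5, so the regional constraint is strongest during the early, structure-defining
# steps. The 1.0 (cross-attn) and 0.3 (self-attn) base weights are experimental tuning values.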
if is_cross_attention:
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=hidden_states.dtype, device=hidden_states.device
)
attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
query_seq_len=query_seq_len,
percent_through=percent_through,
)
attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)
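# Start unmodified block from AttnProcessor2_0 (interrupted further down to inject the regional attention mask).
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv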
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if regional_prompt_data is not None and percent_through < 0.3:
# TODO(ryand): Should this adjustment be skipped for the unconditioned pass?
prompt_region_attention_mask = attn.prepare_attention_mask(
prompt_region_attention_mask, sequence_length, batch_size
)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
prompt_region_attention_mask = prompt_region_attention_mask.view(
batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
)
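# DenseDiffusion-style score modulation: m_pos is each score's distance below its row maximum and m_neg its
# distance above its row minimum. Scores inside a prompt's masked region are pushed toward the row max, scores
# outside it toward the row min, scaled by attn_mask_weight.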
scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]
prompt_region_attention_mask = attn_mask_weight * (
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
else:
pass
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
# Apply IP-Adapter conditioning.
if is_cross_attention:
if self._is_ip_adapter_enabled():
assert ip_adapter_image_prompt_embeds is not None
for ipa_embed, ipa_weights, scale in zip(
ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_states shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
else:
# If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
assert ip_adapter_image_prompt_embeds is None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
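
For context, a rough usage sketch (not part of this diff) of how the processor plugs into a diffusers UNet. `unet`, `latents`, `timestep`, `text_embeds`, `regional_prompt_data`, `step_index`, and `total_step_count` are assumed to exist; the `UNetAttentionPatcher` further down in this diff is the supported way to install the processors.

from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0

# Install the custom processor on every attention layer (no IP-Adapter weights in this sketch).
unet.set_attn_processor({name: CustomAttnProcessor2_0() for name in unet.attn_processors})

noise_pred = unet(
    latents,
    timestep,
    encoder_hidden_states=text_embeds,
    cross_attention_kwargs={
        # Forwarded by diffusers to CustomAttnProcessor2_0.__call__() on every attention layer.
        "regional_prompt_data": regional_prompt_data,
        "percent_through": step_index / total_step_count,
    },
).sample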

View File

@ -0,0 +1,164 @@
import torch
import torch.nn.functional as F
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:
def __init__(
self,
regions: list[TextConditioningRegions],
device: torch.device,
dtype: torch.dtype,
max_downscale_factor: int = 8,
):
"""Initialize a `RegionalPromptData` object.
Args:
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
batch.
device (torch.device): The device to use for the attention masks.
dtype (torch.dtype): The data type to use for the attention masks.
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
in steps of 2x.
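For example, with a 64x64 latent (a 512x512 image) and the default max_downscale_factor of 8, masks are
prepared for query sequence lengths 4096, 1024, 256, and 64.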
"""
self._regions = regions
self._device = device
self._dtype = dtype
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
# sequence length of s.
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor
)
self._negative_cross_attn_mask_score = 0.0
self._size_weight = 1.0
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]:
"""Prepare the spatial masks for all downscaling factors."""
# batch_sample_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence
# length of s.
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_sample_regions in regions:
batch_sample_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
b, _num_prompts, h, w = batch_sample_masks.shape
assert b == 1
query_seq_len = h * w
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# regions to be lost entirely.
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
# potentially use a weighted mask rather than a binary mask.
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
return batch_sample_masks_by_seq_len
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
"""Get the cross-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
Returns:
torch.Tensor: The masks.
shape: (batch_size, query_seq_len, key_seq_len).
dtype: float
The mask is 0.0 outside a prompt's region; inside a region it carries that prompt's weight (mask_weight plus a size-dependent bonus).
"""
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# size = size.to(dtype=batch_sample_query_scores.dtype)
# batch_sample_query_mask = batch_sample_query_scores > 0.5
# batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
# batch_sample_query_scores[~batch_sample_query_mask] = 0.0
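# Weight each query position inside the region by the prompt's mask_weight plus a bonus that grows as the
# region shrinks, so that prompts covering small regions are not drowned out by prompts covering most of the
# image.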
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
mask_weight + self._size_weight * (1 - size)
)
return attn_mask
def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
"""Get the self-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
Returns:
torch.Tensor: The masks.
shape: (batch_size, query_seq_len, query_seq_len).
dtype: float
The mask is 0.0 for query pairs that do not share a prompt region; pairs within the same region carry that prompt's weight (mask_weight plus a size-dependent bonus).
"""
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx in range(num_prompts):
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
size = prompt_query_mask.sum() / prompt_query_mask.numel()
size = size.to(dtype=prompt_query_mask.dtype)
mask_weight = batch_sample_regions.mask_weights[prompt_idx]
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask.
# Overlapping prompt regions are combined with an elementwise max so that overlaps do not accumulate weight.
attn_mask[batch_idx, :, :] = torch.maximum(
attn_mask[batch_idx, :, :],
prompt_query_mask.unsqueeze(0)
* prompt_query_mask.unsqueeze(1)
* (mask_weight + self._size_weight * (1 - size)),
)
# if attn_mask[batch_idx].max() < 0.01:
# attn_mask[batch_idx, ...] = 1.0
# attn_mask[attn_mask > 0.5] = 1.0
# attn_mask[attn_mask <= 0.5] = 0.0
# attn_mask_min = attn_mask[batch_idx].min()
# # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
# if abs(attn_mask_min) > 0.0001:
# attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
return attn_mask
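
A hypothetical end-to-end sketch (not part of this diff) of constructing RegionalPromptData for a two-prompt layout and querying its masks; the Range and TextConditioningRegions fields follow their usage elsewhere in this diff, and the token ranges / weights shown are illustrative.

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

h, w = 64, 64  # latent spatial dimensions
masks = torch.zeros((1, 2, h, w), dtype=torch.bool)
masks[0, 0, :, : w // 2] = True  # prompt 0 covers the left half
masks[0, 1, :, w // 2 :] = True  # prompt 1 covers the right half

regions = TextConditioningRegions(
    masks=masks,
    # Token range of each prompt within the concatenated text embedding.
    ranges=[Range(start=0, end=77), Range(start=77, end=154)],
    mask_weights=[1.0, 1.0],
)

rpd = RegionalPromptData(regions=[regions], device=torch.device("cpu"), dtype=torch.float32)
cross_mask = rpd.get_cross_attn_mask(query_seq_len=h * w, key_seq_len=154)  # (1, 4096, 154)
self_mask = rpd.get_self_attn_mask(query_seq_len=h * w, percent_through=0.1)  # (1, 4096, 4096)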

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import math import math
import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
@ -10,11 +11,13 @@ from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo, ExtraConditioningInfo,
PostprocessingSettings, IPAdapterConditioningInfo,
SDXLConditioningInfo, Range,
TextConditioningData,
TextConditioningRegions,
) )
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
from .cross_attention_control import ( from .cross_attention_control import (
CrossAttentionType, CrossAttentionType,
@ -56,7 +59,6 @@ class InvokeAIDiffuserComponent:
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
""" """
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
self.conditioning = None
self.model = model self.model = model
self.model_forward_callback = model_forward_callback self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None self.cross_attention_control_context = None
@ -91,7 +93,7 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor, timestep: torch.Tensor,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
conditioning_data, conditioning_data: TextConditioningData,
): ):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
@ -124,38 +126,30 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds, "text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids, "time_ids": conditioning_data.cond_text.add_time_ids,
} }
encoder_hidden_states = conditioning_data.text_embeddings.embeds encoder_hidden_states = conditioning_data.cond_text.embeds
encoder_attention_mask = None encoder_attention_mask = None
else: else:
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat( "text_embeds": torch.cat(
[ [
# TODO: how to pad? just by zeros? or even truncate? conditioning_data.uncond_text.pooled_embeds,
conditioning_data.unconditioned_embeddings.pooled_embeds, conditioning_data.cond_text.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
], ],
dim=0, dim=0,
), ),
"time_ids": torch.cat( "time_ids": torch.cat(
[ [conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
],
dim=0, dim=0,
), ),
} }
( (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
encoder_hidden_states, conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
encoder_attention_mask,
) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds,
conditioning_data.text_embeddings.embeds,
) )
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
@ -199,16 +193,17 @@ class InvokeAIDiffuserComponent:
self, self,
sample: torch.Tensor, sample: torch.Tensor,
timestep: torch.Tensor, timestep: torch.Tensor,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
): ):
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
percent_through = step_index / total_step_count
cross_attention_control_types_to_do = ( cross_attention_control_types_to_do = (
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through) self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
) )
@ -224,6 +219,8 @@ class InvokeAIDiffuserComponent:
x=sample, x=sample,
sigma=timestep, sigma=timestep,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
ip_adapter_conditioning=ip_adapter_conditioning,
percent_through=percent_through,
cross_attention_control_types_to_do=cross_attention_control_types_to_do, cross_attention_control_types_to_do=cross_attention_control_types_to_do,
down_block_additional_residuals=down_block_additional_residuals, down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual, mid_block_additional_residual=mid_block_additional_residual,
@ -237,6 +234,8 @@ class InvokeAIDiffuserComponent:
x=sample, x=sample,
sigma=timestep, sigma=timestep,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
percent_through=percent_through,
ip_adapter_conditioning=ip_adapter_conditioning,
down_block_additional_residuals=down_block_additional_residuals, down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual, mid_block_additional_residual=mid_block_additional_residual,
down_intrablock_additional_residuals=down_intrablock_additional_residuals, down_intrablock_additional_residuals=down_intrablock_additional_residuals,
@ -244,19 +243,6 @@ class InvokeAIDiffuserComponent:
return unconditioned_next_x, conditioned_next_x return unconditioned_next_x, conditioned_next_x
def do_latent_postprocessing(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
sigma,
step_index,
total_step_count,
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = step_index / total_step_count
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
return latents
def _concat_conditionings_for_batch(self, unconditioning, conditioning): def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _pad_conditioning(cond, target_len, encoder_attention_mask): def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones( conditioning_attention_mask = torch.ones(
@ -304,13 +290,13 @@ class InvokeAIDiffuserComponent:
return torch.cat([unconditioning, conditioning]), encoder_attention_mask return torch.cat([unconditioning, conditioning]), encoder_attention_mask
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning( def _apply_standard_conditioning(
self, self,
x, x,
sigma, sigma,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -321,41 +307,55 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2) x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None cross_attention_kwargs = {}
if conditioning_data.ip_adapter_conditioning is not None: if ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = { cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
"ip_adapter_image_prompt_embeds": [ torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
torch.stack( for ipa_conditioning in ip_adapter_conditioning
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds] ]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning uncond_text = conditioning_data.uncond_text
] cond_text = conditioning_data.cond_text
}
added_cond_kwargs = None added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat( "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
[ "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
conditioning_data.text_embeddings.add_time_ids,
],
dim=0,
),
} }
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds uncond_text.embeds, cond_text.embeds
) )
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
regions = []
for c, r in [
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
(conditioning_data.cond_text, conditioning_data.cond_regions),
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])],
mask_weights=[0.0],
)
regions.append(r)
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
time.sleep(1.0)
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, x_twice,
sigma_twice, sigma_twice,
@ -374,8 +374,10 @@ class InvokeAIDiffuserComponent:
self, self,
x: torch.Tensor, x: torch.Tensor,
sigma, sigma,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
cross_attention_control_types_to_do: list[CrossAttentionType], cross_attention_control_types_to_do: list[CrossAttentionType],
percent_through: float,
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
@ -422,36 +424,40 @@ class InvokeAIDiffuserComponent:
# Unconditioned pass # Unconditioned pass
##################### #####################
cross_attention_kwargs = None cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
if conditioning_data.ip_adapter_conditioning is not None: if ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = { cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
"ip_adapter_image_prompt_embeds": [ torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) for ipa_conditioning in ip_adapter_conditioning
for ipa_conditioning in conditioning_data.ip_adapter_conditioning ]
]
}
# Prepare cross-attention control kwargs for the unconditioned pass. # Prepare cross-attention control kwargs for the unconditioned pass.
if cross_attn_processor_context is not None: if cross_attn_processor_context is not None:
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
# Prepare SDXL conditioning kwargs for the unconditioned pass. # Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo if conditioning_data.is_sdxl():
if is_sdxl:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, "text_embeds": conditioning_data.uncond_text.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, "time_ids": conditioning_data.uncond_text.add_time_ids,
} }
# Prepare prompt regions for the unconditioned pass.
if conditioning_data.uncond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
# Run unconditioned UNet denoising (i.e. negative prompt). # Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback( unconditioned_next_x = self.model_forward_callback(
x, x,
sigma, sigma,
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.uncond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block, down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
@ -463,36 +469,41 @@ class InvokeAIDiffuserComponent:
# Conditioned pass # Conditioned pass
################### ###################
cross_attention_kwargs = None cross_attention_kwargs = {}
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass. # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
if conditioning_data.ip_adapter_conditioning is not None: if ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = { cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
"ip_adapter_image_prompt_embeds": [ torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) for ipa_conditioning in ip_adapter_conditioning
for ipa_conditioning in conditioning_data.ip_adapter_conditioning ]
]
}
# Prepare cross-attention control kwargs for the conditioned pass. # Prepare cross-attention control kwargs for the conditioned pass.
if cross_attn_processor_context is not None: if cross_attn_processor_context is not None:
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context} cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
# Prepare SDXL conditioning kwargs for the conditioned pass. # Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None added_cond_kwargs = None
if is_sdxl: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds, "text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids, "time_ids": conditioning_data.cond_text.add_time_ids,
} }
# Prepare prompt regions for the conditioned pass.
if conditioning_data.cond_regions is not None:
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
)
cross_attention_kwargs["percent_through"] = percent_through
# Run conditioned UNet denoising (i.e. positive prompt). # Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback( conditioned_next_x = self.model_forward_callback(
x, x,
sigma, sigma,
conditioning_data.text_embeddings.embeds, conditioning_data.cond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block, down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,
@ -506,64 +517,3 @@ class InvokeAIDiffuserComponent:
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x return combined_next_x
def apply_symmetry(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through: float,
) -> torch.Tensor:
# Reset our last percent through if this is our first step.
if percent_through == 0.0:
self.last_percent_through = 0.0
if postprocessing_settings is None:
return latents
# Check for out of bounds
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
h_symmetry_time_pct = None
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
v_symmetry_time_pct = None
dev = latents.device.type
latents.to(device="cpu")
if (
h_symmetry_time_pct is not None
and self.last_percent_through < h_symmetry_time_pct
and percent_through >= h_symmetry_time_pct
):
# Horizontal symmetry occurs on the 3rd dimension of the latent
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat(
[
latents[:, :, :, 0 : int(width / 2)],
x_flipped[:, :, :, int(width / 2) : int(width)],
],
dim=3,
)
if (
v_symmetry_time_pct is not None
and self.last_percent_through < v_symmetry_time_pct
and percent_through >= v_symmetry_time_pct
):
# Vertical symmetry occurs on the 2nd dimension of the latent
height = latents.shape[2]
y_flipped = torch.flip(latents, dims=[2])
latents = torch.cat(
[
latents[:, :, 0 : int(height / 2)],
y_flipped[:, :, int(height / 2) : int(height)],
],
dim=2,
)
self.last_percent_through = percent_through
return latents.to(device=dev)

View File

@ -1,52 +1,55 @@
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0
class UNetPatcher: class UNetAttentionPatcher:
"""A class that contains multiple IP-Adapters and can apply them to a UNet.""" """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
def __init__(self, ip_adapters: list[IPAdapter]): def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
self._ip_adapters = ip_adapters self._ip_adapters = ip_adapters
self._scales = [1.0] * len(self._ip_adapters) self._ip_adapter_scales = None
if self._ip_adapters is not None:
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
def set_scale(self, idx: int, value: float): def set_scale(self, idx: int, value: float):
self._scales[idx] = value self._ip_adapter_scales[idx] = value
def _prepare_attention_processors(self, unet: UNet2DConditionModel): def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them. weights into them (if IP-Adapters are being applied).
Note that the `unet` param is only used to determine attention block dimensions and naming. Note that the `unet` param is only used to determine attention block dimensions and naming.
""" """
# Construct a dict of attention processors based on the UNet's architecture. # Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {} attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()): for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor"): if name.endswith("attn1.processor") or self._ip_adapters is None:
attn_procs[name] = AttnProcessor2_0() # "attn1" processors do not use IP-Adapters.
attn_procs[name] = CustomAttnProcessor2_0()
else: else:
# Collect the weights from each IP Adapter for the idx'th attention processor. # Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0( attn_procs[name] = CustomAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters], [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._scales, self._ip_adapter_scales,
) )
return attn_procs return attn_procs
@contextmanager @contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with IP-Adapter attention processors.""" """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
attn_procs = self._prepare_attention_processors(unet) attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors orig_attn_processors = unet.attn_processors
try: try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs) unet.set_attn_processor(attn_procs)
yield None yield None
finally: finally:

View File

@ -1,8 +1,8 @@
import pytest import pytest
import torch import torch
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.util.test_utils import install_and_load_model from invokeai.backend.util.test_utils import install_and_load_model
@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device) ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]} cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
ip_adapter_unet_patcher = UNetPatcher([ip_adapter]) ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet): with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
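
The test above exercises the IP-Adapter path. As a complementary, hypothetical sketch (not part of this diff), the patcher can now also be constructed without IP-Adapters, installing plain CustomAttnProcessor2_0 layers so that regional prompt kwargs are honored during denoising; `unet` and the surrounding denoising loop are assumed to exist.

from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

patcher = UNetAttentionPatcher(ip_adapters=None)  # regional prompting only, no IP-Adapters
with patcher.apply_ip_adapter_attention(unet):
    # Run denoising here; cross_attention_kwargs may carry regional_prompt_data / percent_through.
    ...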