Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00

Compare commits: lstein/fea ... ryan/dense (52 commits)
Commits (SHA1, author and date not captured):
ff950bc5cd, 969982b789, b8cbff828b, d3a40c5b2b, 57266d36a2, 41e1a9f202, bcfb43e5f0, a665f20fb5,
d313e5eb70, 271f8f2414, 5fad379192, ad18429fe3, 942efa011e, ffc4ebb14c, 5b3adf0740, a5c94fba43,
3e14bd6c45, 8721926f14, d87ff3a206, 7d9671014b, 4a1acd4db9, 8989a6cdc6, f44d3da9b1, 1bbd4f751d,
bdf3691ad0, e7f7ae660d, e132afb705, 5f49e7ae26, 53ebca58ff, ee1b3157ce, e7ec13f209, cad3e5dbd7,
845c4e93ae, 54971afe44, cfba51aed5, 2966c8de2c, b0fcbe552e, d132fb4818, 2d5d370f38, 878bbc3527,
caa690e24d, 38248b988f, ba4788007f, ef51005881, 7b0326d7f7, f590b39f88, 58277c6ada, 382fa57f3b,
ee3abc171d, bf72cee555, e866e3b19f, 16e574825c
@@ -5,7 +5,15 @@ from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
+from invokeai.app.invocations.fields import (
+    ConditioningField,
+    FieldDescriptions,
+    Input,
+    InputField,
+    MaskField,
+    OutputField,
+    UIComponent,
+)
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
@@ -36,7 +44,7 @@ from .model import ClipField
     title="Prompt",
     tags=["prompt", "compel"],
     category="conditioning",
-    version="1.0.1",
+    version="1.2.0",
 )
 class CompelInvocation(BaseInvocation):
     """Parse prompt using compel package to conditioning."""
@@ -51,6 +59,10 @@ class CompelInvocation(BaseInvocation):
         description=FieldDescriptions.clip,
         input=Input.Connection,
     )
+    mask: Optional[MaskField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )
+    mask_weight: float = InputField(default=1.0, description="")

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
@@ -118,7 +130,13 @@ class CompelInvocation(BaseInvocation):

         conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+                mask=self.mask,
+                mask_weight=self.mask_weight,
+            )
+        )


 class SDXLPromptInvocationBase:
@@ -232,7 +250,7 @@ class SDXLPromptInvocationBase:
     title="SDXL Prompt",
     tags=["sdxl", "compel", "prompt"],
     category="conditioning",
-    version="1.0.1",
+    version="1.2.0",
 )
 class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     """Parse prompt using compel package to conditioning."""
@@ -256,6 +274,11 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
     clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
     clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
+
+    mask: Optional[MaskField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )
+    mask_weight: float = InputField(default=1.0, description="")

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         c1, c1_pooled, ec1 = self.run_clip_compel(
@@ -317,7 +340,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

         conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(
+            conditioning=ConditioningField(
+                conditioning_name=conditioning_name,
+                mask=self.mask,
+                mask_weight=self.mask_weight,
+            )
+        )


 @invocation(
@@ -366,7 +395,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):

         conditioning_name = context.conditioning.save(conditioning_data)

-        return ConditioningOutput.build(conditioning_name)
+        return ConditioningOutput(conditioning=ConditioningField(conditioning_name=conditioning_name, mask_weight=1.0))


 @invocation_output("clip_skip_output")
invokeai/app/invocations/conditioning.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+import torch
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    InvocationContext,
+    invocation,
+)
+from invokeai.app.invocations.fields import InputField, WithMetadata
+from invokeai.app.invocations.primitives import MaskField, MaskOutput
+
+
+@invocation(
+    "rectangle_mask",
+    title="Create Rectangle Mask",
+    tags=["conditioning"],
+    category="conditioning",
+    version="1.0.0",
+)
+class RectangleMaskInvocation(BaseInvocation, WithMetadata):
+    """Create a rectangular mask."""
+
+    height: int = InputField(description="The height of the entire mask.")
+    width: int = InputField(description="The width of the entire mask.")
+    y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
+    x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
+    rectangle_height: int = InputField(description="The height of the rectangular masked region.")
+    rectangle_width: int = InputField(description="The width of the rectangular masked region.")
+
+    def invoke(self, context: InvocationContext) -> MaskOutput:
+        mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
+        mask[
+            :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
+        ] = True
+
+        mask_name = context.tensors.save(mask)
+        return MaskOutput(
+            mask=MaskField(mask_name=mask_name),
+            width=self.width,
+            height=self.height,
+        )
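The new RectangleMaskInvocation boils down to boolean tensor slicing. Below is a minimal, standalone sketch of that logic only (the sizes are hypothetical and the helper is independent of the InvokeAI node API):

import torch

# Hypothetical dimensions, for illustration only.
height, width = 64, 96            # size of the full mask
y_top, x_left = 8, 16             # top-left corner of the masked rectangle
rect_h, rect_w = 32, 48           # rectangle size

# Same construction as the node: start all-False, set the rectangle to True.
mask = torch.zeros((1, height, width), dtype=torch.bool)
mask[:, y_top : y_top + rect_h, x_left : x_left + rect_w] = True

assert mask.shape == (1, height, width)
assert int(mask.sum()) == rect_h * rect_w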
@@ -194,6 +194,12 @@ class BoardField(BaseModel):
     board_id: str = Field(description="The id of the board")


+class MaskField(BaseModel):
+    """A mask primitive field."""
+
+    mask_name: str = Field(description="The name of the mask.")
+
+
 class DenoiseMaskField(BaseModel):
     """An inpaint mask field"""

@@ -225,7 +231,12 @@ class ConditioningField(BaseModel):
     """A conditioning tensor primitive value"""

     conditioning_name: str = Field(description="The name of conditioning tensor")
-# endregion
+    mask: Optional[MaskField] = Field(
+        default=None,
+        description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
+        "included regions should be set to True.",
+    )
+    mask_weight: float = Field(description="")


 class MetadataField(RootModel):
@@ -1,5 +1,5 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import inspect
 import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
@@ -9,6 +9,7 @@ import einops
 import numpy as np
 import numpy.typing as npt
 import torch
+import torchvision
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.configuration_utils import ConfigMixin
@@ -55,7 +56,14 @@ from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager import BaseModelType, LoadedModel
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    BasicConditioningInfo,
+    IPAdapterConditioningInfo,
+    Range,
+    SDXLConditioningInfo,
+    TextConditioningData,
+    TextConditioningRegions,
+)
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 from ...backend.stable_diffusion.diffusers_pipeline import (
@@ -65,7 +73,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
     T2IAdapterData,
     image_resized_to_grid_as_tensor,
 )
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
 from .baseinvocation import (
@@ -284,11 +291,11 @@ def get_scheduler(
 class DenoiseLatentsInvocation(BaseInvocation):
     """Denoises noisy latents to decodable images"""

-    positive_conditioning: ConditioningField = InputField(
+    positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
     )
-    negative_conditioning: ConditioningField = InputField(
-        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+    negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
     )
     noise: Optional[LatentsField] = InputField(
         default=None,
@@ -365,39 +372,190 @@ class DenoiseLatentsInvocation(BaseInvocation):
             raise ValueError("cfg_scale must be greater than 1")
         return v

+    def _get_text_embeddings_and_masks(
+        self,
+        cond_list: list[ConditioningField],
+        context: InvocationContext,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+        """Get the text embeddings and masks from the input conditioning fields."""
+        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        for cond in cond_list:
+            cond_data = context.conditioning.load(cond.conditioning_name)
+            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+            mask = cond.mask
+            if mask is not None:
+                mask = context.tensors.load(mask.mask_name)
+            text_embeddings_masks.append(mask)
+
+        return text_embeddings, text_embeddings_masks
+
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
+        """
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        mask = tf(mask)
+
+        return mask
+
+    def concat_regional_text_embeddings(
+        self,
+        text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
+        masks: Optional[list[Optional[torch.Tensor]]],
+        conditioning_fields: list[ConditioningField],
+        latent_height: int,
+        latent_width: int,
+    ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+        """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
+        if masks is None:
+            masks = [None] * len(text_conditionings)
+        assert len(text_conditionings) == len(masks)
+
+        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
+
+        all_masks_are_none = all(mask is None for mask in masks)
+
+        text_embedding = []
+        pooled_embedding = None
+        add_time_ids = None
+        cur_text_embedding_len = 0
+        processed_masks = []
+        embedding_ranges = []
+        extra_conditioning = None
+
+        for prompt_idx, text_embedding_info in enumerate(text_conditionings):
+            mask = masks[prompt_idx]
+            if (
+                text_embedding_info.extra_conditioning is not None
+                and text_embedding_info.extra_conditioning.wants_cross_attention_control
+            ):
+                extra_conditioning = text_embedding_info.extra_conditioning
+
+            if is_sdxl:
+                # We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
+                # prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
+                # global prompt information. In an ideal case, there should be exactly one global prompt without a
+                # mask, but we don't enforce this.
+
+                # HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
+                # fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
+                # them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+                # pretty major breaking change to a popular node, so for now we use this hack.
+                if pooled_embedding is None or mask is None:
+                    pooled_embedding = text_embedding_info.pooled_embeds
+                if add_time_ids is None or mask is None:
+                    add_time_ids = text_embedding_info.add_time_ids
+
+            text_embedding.append(text_embedding_info.embeds)
+            if not all_masks_are_none:
+                # embedding_ranges.append(
+                #     Range(
+                #         start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
+                #     )
+                # )
+                # HACK(ryand): Contrary to its name, tokens_count_including_eos_bos does not seem to include eos and bos
+                # in the count.
+                embedding_ranges.append(
+                    Range(
+                        start=cur_text_embedding_len + 1,
+                        end=cur_text_embedding_len
+                        + text_embedding_info.extra_conditioning.tokens_count_including_eos_bos,
+                    )
+                )
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))

+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+        text_embedding = torch.cat(text_embedding, dim=1)
+        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
+
+        regions = None
+        if not all_masks_are_none:
+            regions = TextConditioningRegions(
+                masks=torch.cat(processed_masks, dim=1),
+                ranges=embedding_ranges,
+                mask_weights=[x.mask_weight for x in conditioning_fields],
+            )
+
+        if extra_conditioning is not None and len(text_conditionings) > 1:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
+                "prompts."
+            )
+
+        if is_sdxl:
+            return SDXLConditioningInfo(
+                embeds=text_embedding,
+                extra_conditioning=extra_conditioning,
+                pooled_embeds=pooled_embedding,
+                add_time_ids=add_time_ids,
+            ), regions
+        return BasicConditioningInfo(
+            embeds=text_embedding,
+            extra_conditioning=extra_conditioning,
+        ), regions
+
     def get_conditioning_data(
         self,
         context: InvocationContext,
-        scheduler: Scheduler,
         unet: UNet2DConditionModel,
-        seed: int,
-    ) -> ConditioningData:
-        positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
-
-        negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
-
-        conditioning_data = ConditioningData(
-            unconditioned_embeddings=uc,
-            text_embeddings=c,
-            guidance_scale=self.cfg_scale,
-            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
-            postprocessing_settings=PostprocessingSettings(
-                threshold=0.0,  # threshold,
-                warmup=0.2,  # warmup,
-                h_symmetry_time_pct=None,  # h_symmetry_time_pct,
-                v_symmetry_time_pct=None,  # v_symmetry_time_pct,
-            ),
-        )
-
-        conditioning_data = conditioning_data.add_scheduler_args_if_applicable(  # FIXME
-            scheduler,
-            # for ddim scheduler
-            eta=0.0,  # ddim_eta
-            # for ancestral and sde schedulers
-            # flip all bits to have noise different from initial
-            generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
+        latent_height: int,
+        latent_width: int,
+    ) -> TextConditioningData:
+        # Normalize self.positive_conditioning and self.negative_conditioning to lists.
+        cond_list = self.positive_conditioning
+        if not isinstance(cond_list, list):
+            cond_list = [cond_list]
+        uncond_list = self.negative_conditioning
+        if not isinstance(uncond_list, list):
+            uncond_list = [uncond_list]
+
+        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            cond_list, context, unet.device, unet.dtype
+        )
+        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            uncond_list, context, unet.device, unet.dtype
+        )
+        cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
+            text_conditionings=cond_text_embeddings,
+            masks=cond_text_embedding_masks,
+            conditioning_fields=cond_list,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
+        uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
+            text_conditionings=uncond_text_embeddings,
+            masks=uncond_text_embedding_masks,
+            conditioning_fields=uncond_list,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
+        conditioning_data = TextConditioningData(
+            uncond_text=uncond_text_embedding,
+            cond_text=cond_text_embedding,
+            uncond_regions=uncond_regions,
+            cond_regions=cond_regions,
+            guidance_scale=self.cfg_scale,
+            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
         )
         return conditioning_data

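A small, self-contained sketch of the regional-mask handling that `_preprocess_regional_prompt_mask` and `concat_regional_text_embeddings` perform: each prompt's bool mask is brought to the latent resolution with nearest-neighbor resizing (an all-ones mask stands in for prompts without a mask), and the per-prompt masks are stacked along dim=1. Helper names and sizes below are assumptions for illustration; the float-resize-then-threshold step is a dtype-safe stand-in for the node's nearest-neighbor resize.

import torch
import torchvision.transforms as T

def preprocess_mask(mask, target_height, target_width):
    """Return a (1, 1, target_height, target_width) bool mask for one prompt."""
    if mask is None:
        # A prompt without a mask applies to the whole image.
        return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
    resize = T.Resize((target_height, target_width), interpolation=T.InterpolationMode.NEAREST)
    # Resize in float, then threshold back to bool.
    resized = resize(mask.unsqueeze(0).to(torch.float32))
    return resized > 0.5

latent_h, latent_w = 64, 64
rect = torch.zeros((1, 512, 512), dtype=torch.bool)
rect[:, 0:256, 0:256] = True  # mask covering the top-left quadrant

processed = [preprocess_mask(m, latent_h, latent_w) for m in (None, rect)]
stacked = torch.cat(processed, dim=1)  # shape: (1, num_prompts, latent_h, latent_w)
assert stacked.shape == (1, 2, latent_h, latent_w)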
@@ -503,7 +661,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         self,
         context: InvocationContext,
         ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
-        conditioning_data: ConditioningData,
         exit_stack: ExitStack,
     ) -> Optional[list[IPAdapterData]]:
         """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
@@ -520,7 +677,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
             return None

         ip_adapter_data_list = []
-        conditioning_data.ip_adapter_conditioning = []
         for single_ip_adapter in ip_adapter:
             ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
                 context.models.load(key=single_ip_adapter.ip_adapter_model.key)
@@ -543,16 +699,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 single_ipa_images, image_encoder_model
             )

-            conditioning_data.ip_adapter_conditioning.append(
-                IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
-            )
-
             ip_adapter_data_list.append(
                 IPAdapterData(
                     ip_adapter_model=ip_adapter_model,
                     weight=single_ip_adapter.weight,
                     begin_step_percent=single_ip_adapter.begin_step_percent,
                     end_step_percent=single_ip_adapter.end_step_percent,
+                    ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
                 )
             )

@@ -642,6 +795,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         steps: int,
         denoising_start: float,
         denoising_end: float,
+        seed: int,
     ) -> Tuple[int, List[int], int]:
         assert isinstance(scheduler, ConfigMixin)
         if scheduler.config.get("cpu_only", False):
@@ -670,7 +824,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
         timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
         num_inference_steps = len(timesteps) // scheduler.order

-        return num_inference_steps, timesteps, init_timestep
+        scheduler_step_kwargs = {}
+        scheduler_step_signature = inspect.signature(scheduler.step)
+        if "generator" in scheduler_step_signature.parameters:
+            # At some point, someone decided that schedulers that accept a generator should use the original seed with
+            # all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
+            # reproducibility.
+            scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
+
+        return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs

     def prep_inpaint_mask(
         self, context: InvocationContext, latents: torch.Tensor
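The `scheduler_step_kwargs` change above follows a simple pattern: inspect `scheduler.step` and only pass a `generator` if the signature accepts one, seeded with the original seed with all bits flipped. A standalone sketch of that pattern (the dummy scheduler below is an assumption for illustration, not a diffusers class):

import inspect
import torch

def build_scheduler_step_kwargs(scheduler, device: torch.device, seed: int) -> dict:
    """Only pass `generator` to scheduler.step() if the scheduler accepts it."""
    step_kwargs = {}
    if "generator" in inspect.signature(scheduler.step).parameters:
        # Flip all bits of the seed, matching the behaviour kept above for reproducibility.
        step_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)
    return step_kwargs

class DummyScheduler:  # stand-in for a scheduler whose step() takes a generator
    def step(self, model_output, timestep, sample, generator=None):
        pass

kwargs = build_scheduler_step_kwargs(DummyScheduler(), torch.device("cpu"), seed=123)
assert "generator" in kwargs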
@@ -763,7 +925,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             )

             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )

             controlnet_data = self.prep_control_data(
                 context=context,
@@ -777,16 +942,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
             ip_adapter_data = self.prep_ip_adapter_data(
                 context=context,
                 ip_adapter=self.ip_adapter,
-                conditioning_data=conditioning_data,
                 exit_stack=exit_stack,
             )

-            num_inference_steps, timesteps, init_timestep = self.init_scheduler(
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
                 scheduler,
                 device=unet.device,
                 steps=self.steps,
                 denoising_start=self.denoising_start,
                 denoising_end=self.denoising_end,
+                seed=seed,
             )

             result_latents = pipeline.latents_from_embeddings(
@@ -799,6 +964,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 masked_latents=masked_latents,
                 gradient_mask=gradient_mask,
                 num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
                 conditioning_data=conditioning_data,
                 control_data=controlnet_data,
                 ip_adapter_data=ip_adapter_data,
@@ -14,6 +14,7 @@ from invokeai.app.invocations.fields import (
     Input,
     InputField,
     LatentsField,
+    MaskField,
     OutputField,
     UIComponent,
 )
@@ -229,6 +230,18 @@ class StringCollectionInvocation(BaseInvocation):
 # region Image


+@invocation_output("mask_output")
+class MaskOutput(BaseInvocationOutput):
+    """A torch mask tensor.
+    dtype: torch.bool
+    shape: (1, height, width).
+    """
+
+    mask: MaskField = OutputField(description="The mask.")
+    width: int = OutputField(description="The width of the mask in pixels.")
+    height: int = OutputField(description="The height of the mask in pixels.")
+
+
 @invocation_output("image_output")
 class ImageOutput(BaseInvocationOutput):
     """Base class for nodes that output a single image"""
@@ -414,10 +427,6 @@ class ConditioningOutput(BaseInvocationOutput):

     conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)

-    @classmethod
-    def build(cls, conditioning_name: str) -> "ConditioningOutput":
-        return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
-

 @invocation_output("conditioning_collection_output")
 class ConditioningCollectionOutput(BaseInvocationOutput):
Deleted file (IP-Adapter attention processor, 182 lines):
@@ -1,182 +0,0 @@
-# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
-# and modified as needed
-
-# tencent-ailab comment:
-# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
-
-from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
-
-
-# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
-# loading.
-class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
-    def __init__(self):
-        DiffusersAttnProcessor2_0.__init__(self)
-        nn.Module.__init__(self)
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        ip_adapter_image_prompt_embeds=None,
-    ):
-        """Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
-        ip_adapter_image_prompt_embeds parameter.
-        """
-        return DiffusersAttnProcessor2_0.__call__(
-            self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
-        )
-
-
-class IPAttnProcessor2_0(torch.nn.Module):
-    r"""
-    Attention processor for IP-Adapater for PyTorch 2.0.
-    Args:
-        hidden_size (`int`):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`):
-            The number of channels in the `encoder_hidden_states`.
-        scale (`float`, defaults to 1.0):
-            the weight scale of image prompt.
-    """
-
-    def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-        assert len(weights) == len(scales)
-
-        self._weights = weights
-        self._scales = scales
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        ip_adapter_image_prompt_embeds=None,
-    ):
-        """Apply IP-Adapter attention.
-
-        Args:
-            ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
-                Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
-        """
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
-            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
-            assert ip_adapter_image_prompt_embeds is not None
-            assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
-
-            for ipa_embed, ipa_weights, scale in zip(
-                ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
-            ):
-                # The batch dimensions should match.
-                assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
-                # The token_len dimensions should match.
-                assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
-
-                ip_hidden_states = ipa_embed
-
-                # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
-
-                ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-                ip_value = ipa_weights.to_v_ip(ip_hidden_states)
-
-                # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-                # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
-
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                ip_hidden_states = F.scaled_dot_product_attention(
-                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-
-                # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
-
-                ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-                ip_hidden_states = ip_hidden_states.to(query.dtype)
-
-                # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
-
-                hidden_states = hidden_states + scale * ip_hidden_states
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
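For context on what the deleted processor did (its role moves to the new UNetAttentionPatcher path): it ran a second scaled-dot-product attention over the IP-Adapter image-prompt keys/values and added the result, scaled, to the text-attention output. A toy sketch of that decoupled cross-attention idea with made-up shapes, not the deleted class itself:

import torch
import torch.nn.functional as F

batch, heads, head_dim = 1, 8, 64
query = torch.randn(batch, heads, 256, head_dim)     # queries from the image latent
text_key = torch.randn(batch, heads, 77, head_dim)   # projected text embeddings
text_value = torch.randn(batch, heads, 77, head_dim)
ip_key = torch.randn(batch, heads, 4, head_dim)      # from the image-prompt K projection
ip_value = torch.randn(batch, heads, 4, head_dim)    # from the image-prompt V projection
scale = 1.0

text_attn = F.scaled_dot_product_attention(query, text_key, text_value)
ip_attn = F.scaled_dot_product_attention(query, ip_key, ip_value)
out = text_attn + scale * ip_attn  # image-prompt attention added on top of text attention
assert out.shape == (batch, heads, 256, head_dim)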
@@ -23,9 +23,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    IPAdapterConditioningInfo,
+    TextConditioningData,
+)
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher

 from ..util import auto_detect_slice_size, normalize_device

@@ -170,10 +173,11 @@ class ControlNetData:

 @dataclass
 class IPAdapterData:
-    ip_adapter_model: IPAdapter = Field(default=None)
-    # TODO: change to polymorphic so can do different weights per step (once implemented...)
+    ip_adapter_model: IPAdapter
+    ip_adapter_conditioning: IPAdapterConditioningInfo
+
+    # Either a single weight applied to all steps, or a list of weights for each step.
     weight: Union[float, List[float]] = Field(default=1.0)
-    # weight: float = Field(default=1.0)
     begin_step_percent: float = Field(default=0.0)
     end_step_percent: float = Field(default=1.0)

@@ -314,7 +318,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self,
         latents: torch.Tensor,
         num_inference_steps: int,
-        conditioning_data: ConditioningData,
+        scheduler_step_kwargs: dict[str, Any],
+        conditioning_data: TextConditioningData,
         *,
         noise: Optional[torch.Tensor],
         timesteps: torch.Tensor,
@@ -374,6 +379,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 latents,
                 timesteps,
                 conditioning_data,
+                scheduler_step_kwargs=scheduler_step_kwargs,
                 additional_guidance=additional_guidance,
                 control_data=control_data,
                 ip_adapter_data=ip_adapter_data,
@@ -393,7 +399,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self,
         latents: torch.Tensor,
         timesteps,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        scheduler_step_kwargs: dict[str, Any],
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
@@ -410,22 +417,35 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents

-        ip_adapter_unet_patcher = None
-        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = (
+            conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+        )
+        if use_cross_attention_control and use_ip_adapter:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
+            )
+        if use_cross_attention_control and use_regional_prompting:
+            raise ValueError(
+                "Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
+            )
+
+        unet_attention_patcher = None
+        self.use_ip_adapter = use_ip_adapter
+        attn_ctx = nullcontext()
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
             )
-            self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
-            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
-            # As it is now, the IP-Adapter will silently be skipped.
-            ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
-            attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
-            self.use_ip_adapter = True
-        else:
-            attn_ctx = nullcontext()
+        if use_ip_adapter or use_regional_prompting:
+            ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
+            unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
+            attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

         with attn_ctx:
             if callback is not None:
@@ -448,22 +468,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                     conditioning_data,
                     step_index=i,
                     total_step_count=len(timesteps),
+                    scheduler_step_kwargs=scheduler_step_kwargs,
                     additional_guidance=additional_guidance,
                     control_data=control_data,
                     ip_adapter_data=ip_adapter_data,
                     t2i_adapter_data=t2i_adapter_data,
-                    ip_adapter_unet_patcher=ip_adapter_unet_patcher,
+                    unet_attention_patcher=unet_attention_patcher,
                 )
                 latents = step_output.prev_sample

-                latents = self.invokeai_diffuser.do_latent_postprocessing(
-                    postprocessing_settings=conditioning_data.postprocessing_settings,
-                    latents=latents,
-                    sigma=batched_t,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                )
-
                 predicted_original = getattr(step_output, "pred_original_sample", None)

                 if callback is not None:
@@ -485,14 +497,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         self,
         t: torch.Tensor,
         latents: torch.Tensor,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
         step_index: int,
         total_step_count: int,
+        scheduler_step_kwargs: dict[str, Any],
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
         ip_adapter_data: Optional[list[IPAdapterData]] = None,
         t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
-        ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
+        unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
     ):
         # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
         timestep = t[0]
@@ -515,10 +528,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                 )
                 if step_index >= first_adapter_step and step_index <= last_adapter_step:
                     # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
-                    ip_adapter_unet_patcher.set_scale(i, weight)
+                    unet_attention_patcher.set_scale(i, weight)
                 else:
                     # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
-                    ip_adapter_unet_patcher.set_scale(i, 0.0)
+                    unet_attention_patcher.set_scale(i, 0.0)

         # Handle ControlNet(s)
         down_block_additional_residuals = None
@@ -562,12 +575,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

             down_intrablock_additional_residuals = accum_adapter_state

+        ip_adapter_conditioning = None
+        if ip_adapter_data is not None:
+            ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
+
         uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
             sample=latent_model_input,
             timestep=t,  # TODO: debug how handled batched and non batched timesteps
             step_index=step_index,
             total_step_count=total_step_count,
             conditioning_data=conditioning_data,
+            ip_adapter_conditioning=ip_adapter_conditioning,
             down_block_additional_residuals=down_block_additional_residuals,  # for ControlNet
             mid_block_additional_residual=mid_block_additional_residual,  # for ControlNet
             down_intrablock_additional_residuals=down_intrablock_additional_residuals,  # for T2I-Adapter
|
|||||||
)
|
)
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||||
|
|
||||||
# TODO: issue to diffusers?
|
# TODO: issue to diffusers?
|
||||||
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
|
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
|
||||||
|
@@ -1,7 +1,5 @@
-import dataclasses
-import inspect
-from dataclasses import dataclass, field
-from typing import Any, List, Optional, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union

 import torch

@@ -10,6 +8,11 @@ from .cross_attention_control import Arguments

 @dataclass
 class ExtraConditioningInfo:
+    """Extra conditioning information produced by Compel.
+
+    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
+    """
+
     tokens_count_including_eos_bos: int
     cross_attention_control_args: Optional[Arguments] = None

@@ -20,6 +23,8 @@ class ExtraConditioningInfo:

 @dataclass
 class BasicConditioningInfo:
+    """SD 1/2 text conditioning information produced by Compel."""
+
     embeds: torch.Tensor
     extra_conditioning: Optional[ExtraConditioningInfo]

@@ -35,6 +40,8 @@ class ConditioningFieldData:

 @dataclass
 class SDXLConditioningInfo(BasicConditioningInfo):
+    """SDXL text conditioning information produced by Compel."""
+
     pooled_embeds: torch.Tensor
     add_time_ids: torch.Tensor

@@ -44,14 +51,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
         return super().to(device=device, dtype=dtype)


-@dataclass(frozen=True)
-class PostprocessingSettings:
-    threshold: float
-    warmup: float
-    h_symmetry_time_pct: Optional[float]
-    v_symmetry_time_pct: Optional[float]
-
-
 @dataclass
 class IPAdapterConditioningInfo:
     cond_image_prompt_embeds: torch.Tensor
@@ -65,41 +64,55 @@ class IPAdapterConditioningInfo:


 @dataclass
-class ConditioningData:
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
-    """
-    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    """
-    guidance_scale: Union[float, List[float]]
-    """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
-    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
-    """
-    guidance_rescale_multiplier: float = 0
-    scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """
-    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
-    """
-    postprocessing_settings: Optional[PostprocessingSettings] = None
-
-    ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
-
-    @property
-    def dtype(self):
-        return self.text_embeddings.dtype
-
-    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
-        scheduler_args = dict(self.scheduler_args)
-        step_method = inspect.signature(scheduler.step)
-        for name, value in kwargs.items():
-            try:
-                step_method.bind_partial(**{name: value})
-            except TypeError:
-                # FIXME: don't silently discard arguments
-                pass  # debug("%s does not accept argument named %r", scheduler, name)
-            else:
-                scheduler_args[name] = value
-        return dataclasses.replace(self, scheduler_args=scheduler_args)
+class Range:
+    start: int
+    end: int
+
+
+class TextConditioningRegions:
+    def __init__(
+        self,
+        masks: torch.Tensor,
+        ranges: list[Range],
+        mask_weights: list[float],
+    ):
+        # A binary mask indicating the regions of the image that the prompt should be applied to.
+        # Shape: (1, num_prompts, height, width)
+        # Dtype: torch.bool
+        self.masks = masks
+
+        # A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
+        # ranges[i] contains the embedding range for the i'th prompt / mask.
+        self.ranges = ranges
+
+        self.mask_weights = mask_weights
+
+        assert self.masks.shape[1] == len(self.ranges) == len(self.mask_weights)
+
+
+class TextConditioningData:
+    def __init__(
+        self,
+        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        uncond_regions: Optional[TextConditioningRegions],
+        cond_regions: Optional[TextConditioningRegions],
+        guidance_scale: Union[float, List[float]],
+        guidance_rescale_multiplier: float = 0,
+    ):
+        self.uncond_text = uncond_text
+        self.cond_text = cond_text
+        self.uncond_regions = uncond_regions
+        self.cond_regions = cond_regions
+        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        self.guidance_scale = guidance_scale
+        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
+        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+        self.guidance_rescale_multiplier = guidance_rescale_multiplier
+
+    def is_sdxl(self):
|
||||||
|
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
|
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||||
|
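As a quick illustration of how these pieces fit together, here is a minimal sketch (not part of the diff; the image size, token ranges, and weights are made up) that builds regional conditioning for two prompts:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range, TextConditioningRegions

# Two prompts whose embeddings were concatenated into one 154-token conditioning tensor
# (assumed 77 tokens each); each prompt is assigned one half of a 64x64 latent-space mask.
masks = torch.zeros((1, 2, 64, 64), dtype=torch.bool)
masks[0, 0, :, :32] = True  # prompt 0 applies to the left half of the image
masks[0, 1, :, 32:] = True  # prompt 1 applies to the right half

regions = TextConditioningRegions(
    masks=masks,
    ranges=[Range(start=0, end=77), Range(start=77, end=154)],  # token span of each prompt
    mask_weights=[1.0, 1.0],
)
# `regions` would then be passed as cond_regions when constructing TextConditioningData.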
invokeai/backend/stable_diffusion/diffusion/custom_attention.py (new file, 242 lines)
@ -0,0 +1,242 @@
import math
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND

from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


class CustomAttnProcessor2_0(AttnProcessor2_0):
    """A custom implementation of AttnProcessor2_0 that supports additional Invoke features.

    This implementation is based on
    https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204

    Supported custom features:
    - IP-Adapter
    - Regional prompt attention
    """

    def __init__(
        self,
        ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
        ip_adapter_scales: Optional[list[float]] = None,
    ):
        """Initialize a CustomAttnProcessor2_0.

        Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
        layer-specific are passed to __init__().

        Args:
            ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
                for the i'th IP-Adapter.
            ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
                the i'th IP-Adapter.
        """
        super().__init__()

        self._ip_adapter_weights = ip_adapter_weights
        self._ip_adapter_scales = ip_adapter_scales

        assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
        if self._ip_adapter_weights is not None:
            assert len(ip_adapter_weights) == len(ip_adapter_scales)

    def _is_ip_adapter_enabled(self) -> bool:
        return self._ip_adapter_weights is not None

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        # For regional prompting:
        regional_prompt_data: Optional[RegionalPromptData] = None,
        percent_through: Optional[float] = None,
        # For IP-Adapter:
        ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.FloatTensor:
        """Apply attention.

        Args:
            regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
                apply regional prompt masking.
            ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
                ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
                tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
        """
        # If true, we are doing cross-attention, if false we are doing self-attention.
        is_cross_attention = encoder_hidden_states is not None

        # Start unmodified block from AttnProcessor2_0.
        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        # End unmodified block from AttnProcessor2_0.

        # Handle regional prompt attention masks.
        if regional_prompt_data is not None:
            assert percent_through is not None
            _, query_seq_len, _ = hidden_states.shape
            if is_cross_attention:
                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
                    query_seq_len=query_seq_len, key_seq_len=sequence_length
                )
                # TODO(ryand): Avoid redundant type/device conversion here.
                prompt_region_attention_mask = prompt_region_attention_mask.to(
                    dtype=hidden_states.dtype, device=hidden_states.device
                )

                attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
            else:  # self-attention
                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                    query_seq_len=query_seq_len,
                    percent_through=percent_through,
                )
                attn_mask_weight = 0.3 * ((1 - percent_through) ** 5)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)
        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if regional_prompt_data is not None and percent_through < 0.3:
            # Don't apply to uncond????

            prompt_region_attention_mask = attn.prepare_attention_mask(
                prompt_region_attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            prompt_region_attention_mask = prompt_region_attention_mask.view(
                batch_size, attn.heads, -1, prompt_region_attention_mask.shape[-1]
            )

            scale_factor = 1 / math.sqrt(query.size(-1))
            attn_weight = query @ key.transpose(-2, -1) * scale_factor
            m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
            m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]

            prompt_region_attention_mask = attn_mask_weight * (
                m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
            )

            if attention_mask is None:
                attention_mask = prompt_region_attention_mask
            else:
                attention_mask = prompt_region_attention_mask + attention_mask
        else:
            pass

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)
        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        # End unmodified block from AttnProcessor2_0.

        # Apply IP-Adapter conditioning.
        if is_cross_attention and self._is_ip_adapter_enabled():
            if self._is_ip_adapter_enabled():
                assert ip_adapter_image_prompt_embeds is not None
                for ipa_embed, ipa_weights, scale in zip(
                    ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
                ):
                    # The batch dimensions should match.
                    assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
                    # The token_len dimensions should match.
                    assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]

                    ip_hidden_states = ipa_embed

                    # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)

                    ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                    ip_value = ipa_weights.to_v_ip(ip_hidden_states)

                    # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)

                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

                    # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)

                    # TODO: add support for attn.scale when we move to Torch 2.1
                    ip_hidden_states = F.scaled_dot_product_attention(
                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                    )

                    # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)

                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
                    ip_hidden_states = ip_hidden_states.to(query.dtype)

                    # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)

                    hidden_states = hidden_states + scale * ip_hidden_states
        else:
            # If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
            assert ip_adapter_image_prompt_embeds is None

        # Start unmodified block from AttnProcessor2_0.
        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
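To see what the regional bias in CustomAttnProcessor2_0 does numerically, here is a small standalone sketch of the m_pos/m_neg weighting on toy tensors (all shapes and values below are illustrative only, not taken from the diff):

import math

import torch

# Toy single-head attention inputs: (batch, heads, query_len, head_dim) and (batch, heads, key_len, head_dim).
query = torch.randn(1, 1, 4, 8)
key = torch.randn(1, 1, 6, 8)
# 1.0 where the prompt's region covers the query position, 0.0 elsewhere.
region_mask = torch.randint(0, 2, (1, 1, 4, 6)).float()

scale_factor = 1 / math.sqrt(query.size(-1))
attn_weight = query @ key.transpose(-2, -1) * scale_factor

# Headroom to the row max (m_pos) and to the row min (m_neg). Masked-in positions are pushed up by at most
# m_pos and masked-out positions are pushed down by at most m_neg, so no logit overshoots the row's existing range.
m_pos = attn_weight.max(dim=-1, keepdim=True)[0] - attn_weight
m_neg = attn_weight - attn_weight.min(dim=-1, keepdim=True)[0]

percent_through = 0.1  # early in denoising, so the bias is strong; it decays as (1 - percent_through) ** 5
attn_mask_weight = 1.0 * ((1 - percent_through) ** 5)
bias = attn_mask_weight * (m_pos * region_mask - m_neg * (1.0 - region_mask))

# `bias` would then be added to the attention logits via the attn_mask argument of scaled_dot_product_attention.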
@ -0,0 +1,164 @@
import torch
import torch.nn.functional as F

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    TextConditioningRegions,
)


class RegionalPromptData:
    def __init__(
        self,
        regions: list[TextConditioningRegions],
        device: torch.device,
        dtype: torch.dtype,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `RegionalPromptData` object.

        Args:
            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
                batch.
            device (torch.device): The device to use for the attention masks.
            dtype (torch.dtype): The data type to use for the attention masks.
            max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
                in steps of 2x.
        """
        self._regions = regions
        self._device = device
        self._dtype = dtype
        # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
        # sequence length of s.
        self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
            regions, max_downscale_factor
        )
        self._negative_cross_attn_mask_score = 0.0
        self._size_weight = 1.0

    def _prepare_spatial_masks(
        self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
    ) -> list[dict[int, torch.Tensor]]:
        """Prepare the spatial masks for all downscaling factors."""
        # batch_sample_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
        # sequence length of s.
        batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []

        for batch_sample_regions in regions:
            batch_sample_masks_by_seq_len.append({})

            # Convert the bool masks to float masks so that max pooling can be applied.
            batch_sample_masks = batch_sample_regions.masks.to(device=self._device, dtype=self._dtype)

            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
            downscale_factor = 1
            while downscale_factor <= max_downscale_factor:
                b, _num_prompts, h, w = batch_sample_masks.shape
                assert b == 1
                query_seq_len = h * w

                batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks

                downscale_factor *= 2
                if downscale_factor <= max_downscale_factor:
                    # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
                    # regions to be lost entirely.
                    # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
                    # potentially use a weighted mask rather than a binary mask.
                    batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)

        return batch_sample_masks_by_seq_len

    def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
        """Get the cross-attention mask for the given query sequence length.

        Args:
            query_seq_len: The length of the flattened spatial features at the current downscaling level.
            key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the
                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
                explicitly to be sure.

        Returns:
            torch.Tensor: The masks.
                shape: (batch_size, query_seq_len, key_seq_len).
                dtype: float
                The mask is a binary mask with values of 0.0 and 1.0.
        """
        batch_size = len(self._spatial_masks_by_seq_len)
        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]

        # Create an empty attention mask with the correct shape.
        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len), dtype=self._dtype, device=self._device)

        for batch_idx in range(batch_size):
            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
            batch_sample_regions = self._regions[batch_idx]

            # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
            _, num_prompts, _, _ = batch_sample_spatial_masks.shape
            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))

            for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
                batch_sample_query_scores = batch_sample_query_masks[0, prompt_idx, :, :]
                size = batch_sample_query_scores.sum() / batch_sample_query_scores.numel()
                mask_weight = batch_sample_regions.mask_weights[prompt_idx]
                # size = size.to(dtype=batch_sample_query_scores.dtype)
                # batch_sample_query_mask = batch_sample_query_scores > 0.5
                # batch_sample_query_scores[batch_sample_query_mask] = 1.0 * (1.0 - size)
                # batch_sample_query_scores[~batch_sample_query_mask] = 0.0
                attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_scores * (
                    mask_weight + self._size_weight * (1 - size)
                )

        return attn_mask

    def get_self_attn_mask(self, query_seq_len: int, percent_through: float) -> torch.Tensor:
        """Get the self-attention mask for the given query sequence length.

        Args:
            query_seq_len: The length of the flattened spatial features at the current downscaling level.

        Returns:
            torch.Tensor: The masks.
                shape: (batch_size, query_seq_len, query_seq_len).
                dtype: float
                The mask is a binary mask with values of 0.0 and 1.0.
        """
        batch_size = len(self._spatial_masks_by_seq_len)
        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]

        # Create an empty attention mask with the correct shape.
        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len), dtype=self._dtype, device=self._device)

        for batch_idx in range(batch_size):
            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
            batch_sample_regions = self._regions[batch_idx]

            # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
            _, num_prompts, _, _ = batch_sample_spatial_masks.shape
            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))

            for prompt_idx in range(num_prompts):
                prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0]  # Shape: (query_seq_len,)
                size = prompt_query_mask.sum() / prompt_query_mask.numel()
                size = size.to(dtype=prompt_query_mask.dtype)
                mask_weight = batch_sample_regions.mask_weights[prompt_idx]
                # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
                # query_seq_len) mask.
                # TODO(ryand): Is += really the best option here? Maybe elementwise max is better?
                attn_mask[batch_idx, :, :] = torch.maximum(
                    attn_mask[batch_idx, :, :],
                    prompt_query_mask.unsqueeze(0)
                    * prompt_query_mask.unsqueeze(1)
                    * (mask_weight + self._size_weight * (1 - size)),
                )

            # if attn_mask[batch_idx].max() < 0.01:
            #     attn_mask[batch_idx, ...] = 1.0

            # attn_mask[attn_mask > 0.5] = 1.0
            # attn_mask[attn_mask <= 0.5] = 0.0
            # attn_mask_min = attn_mask[batch_idx].min()

            # # Adjust so that the minimum value is 0.0 regardless of whether all pixels are covered or not.
            # if abs(attn_mask_min) > 0.0001:
            #     attn_mask[batch_idx] = attn_mask[batch_idx] - attn_mask_min
        return attn_mask
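The downscaling loop above can be illustrated in isolation. The sketch below (with an arbitrary 64x64 mask, not taken from the diff) shows which query sequence lengths get an entry and why max pooling is used to preserve small regions:

import torch
import torch.nn.functional as F

mask = torch.zeros((1, 1, 64, 64))
mask[0, 0, 10:12, 10:12] = 1.0  # a tiny 2x2 prompt region

masks_by_seq_len = {}
downscale_factor = 1
while downscale_factor <= 8:
    _, _, h, w = mask.shape
    masks_by_seq_len[h * w] = mask  # keys collected: 4096, 1024, 256, 64
    downscale_factor *= 2
    if downscale_factor <= 8:
        # Max pooling keeps the tiny region alive even at 8x downscale, where average pooling would wash it out.
        mask = F.max_pool2d(mask, kernel_size=2, stride=2)

print(sorted(masks_by_seq_len.keys()))  # [64, 256, 1024, 4096]
print(masks_by_seq_len[64].max())       # tensor(1.) - the region is still present at the coarsest level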
@ -1,6 +1,7 @@
 from __future__ import annotations

 import math
+import time
 from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union

@ -10,11 +11,13 @@ from typing_extensions import TypeAlias

 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    ConditioningData,
     ExtraConditioningInfo,
-    PostprocessingSettings,
-    SDXLConditioningInfo,
+    IPAdapterConditioningInfo,
+    Range,
+    TextConditioningData,
+    TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

 from .cross_attention_control import (
     CrossAttentionType,

@ -56,7 +59,6 @@ class InvokeAIDiffuserComponent:
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
        """
        config = InvokeAIAppConfig.get_config()
-        self.conditioning = None
        self.model = model
        self.model_forward_callback = model_forward_callback
        self.cross_attention_control_context = None

@ -91,7 +93,7 @@ class InvokeAIDiffuserComponent:
        timestep: torch.Tensor,
        step_index: int,
        total_step_count: int,
-        conditioning_data,
+        conditioning_data: TextConditioningData,
    ):
        down_block_res_samples, mid_block_res_sample = None, None

@ -124,38 +126,30 @@ class InvokeAIDiffuserComponent:
                added_cond_kwargs = None

                if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                        added_cond_kwargs = {
-                            "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                            "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                            "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                            "time_ids": conditioning_data.cond_text.add_time_ids,
                        }
-                    encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                    encoder_hidden_states = conditioning_data.cond_text.embeds
                    encoder_attention_mask = None
                else:
-                    if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                        added_cond_kwargs = {
                            "text_embeds": torch.cat(
                                [
-                                    # TODO: how to pad? just by zeros? or even truncate?
-                                    conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                    conditioning_data.text_embeddings.pooled_embeds,
+                                    conditioning_data.uncond_text.pooled_embeds,
+                                    conditioning_data.cond_text.pooled_embeds,
                                ],
                                dim=0,
                            ),
                            "time_ids": torch.cat(
-                                [
-                                    conditioning_data.unconditioned_embeddings.add_time_ids,
-                                    conditioning_data.text_embeddings.add_time_ids,
-                                ],
+                                [conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
                                dim=0,
                            ),
                        }
-                    (
-                        encoder_hidden_states,
-                        encoder_attention_mask,
-                    ) = self._concat_conditionings_for_batch(
-                        conditioning_data.unconditioned_embeddings.embeds,
-                        conditioning_data.text_embeddings.embeds,
+                    (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
+                        conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
                    )
                if isinstance(control_datum.weight, list):
                    # if controlnet has multiple weights, use the weight for the current step

@ -199,16 +193,17 @@ class InvokeAIDiffuserComponent:
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        step_index: int,
        total_step_count: int,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
    ):
+        percent_through = step_index / total_step_count
        cross_attention_control_types_to_do = []
        if self.cross_attention_control_context is not None:
-            percent_through = step_index / total_step_count
            cross_attention_control_types_to_do = (
                self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
            )

@ -224,6 +219,8 @@ class InvokeAIDiffuserComponent:
                x=sample,
                sigma=timestep,
                conditioning_data=conditioning_data,
+                ip_adapter_conditioning=ip_adapter_conditioning,
+                percent_through=percent_through,
                cross_attention_control_types_to_do=cross_attention_control_types_to_do,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,

@ -237,6 +234,8 @@ class InvokeAIDiffuserComponent:
                x=sample,
                sigma=timestep,
                conditioning_data=conditioning_data,
+                percent_through=percent_through,
+                ip_adapter_conditioning=ip_adapter_conditioning,
                down_block_additional_residuals=down_block_additional_residuals,
                mid_block_additional_residual=mid_block_additional_residual,
                down_intrablock_additional_residuals=down_intrablock_additional_residuals,

@ -244,19 +243,6 @@ class InvokeAIDiffuserComponent:

        return unconditioned_next_x, conditioned_next_x

-    def do_latent_postprocessing(
-        self,
-        postprocessing_settings: PostprocessingSettings,
-        latents: torch.Tensor,
-        sigma,
-        step_index,
-        total_step_count,
-    ) -> torch.Tensor:
-        if postprocessing_settings is not None:
-            percent_through = step_index / total_step_count
-            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
-        return latents
-
    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(

@ -304,13 +290,13 @@ class InvokeAIDiffuserComponent:

        return torch.cat([unconditioning, conditioning]), encoder_attention_mask

-    # methods below are called from do_diffusion_step and should be considered private to this class.
-
    def _apply_standard_conditioning(
        self,
        x,
        sigma,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
+        percent_through: float,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter

@ -321,41 +307,55 @@ class InvokeAIDiffuserComponent:
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

-        cross_attention_kwargs = None
-        if conditioning_data.ip_adapter_conditioning is not None:
+        cross_attention_kwargs = {}
+        if ip_adapter_conditioning is not None:
            # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.stack(
-                        [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
-                    )
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
+                for ipa_conditioning in ip_adapter_conditioning
+            ]
+
+        uncond_text = conditioning_data.uncond_text
+        cond_text = conditioning_data.cond_text

        added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [
-                        # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        conditioning_data.text_embeddings.pooled_embeds,
-                    ],
-                    dim=0,
-                ),
-                "time_ids": torch.cat(
-                    [
-                        conditioning_data.unconditioned_embeddings.add_time_ids,
-                        conditioning_data.text_embeddings.add_time_ids,
-                    ],
-                    dim=0,
-                ),
+                "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
+                "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
+            uncond_text.embeds, cond_text.embeds
        )
+
+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
+            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                        mask_weights=[0.0],
+                    )
+                regions.append(r)
+
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=regions, device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+            time.sleep(1.0)
+
        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,

@ -374,8 +374,10 @@ class InvokeAIDiffuserComponent:
        self,
        x: torch.Tensor,
        sigma,
-        conditioning_data: ConditioningData,
+        conditioning_data: TextConditioningData,
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
        cross_attention_control_types_to_do: list[CrossAttentionType],
+        percent_through: float,
        down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
        mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
        down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter

@ -422,36 +424,40 @@ class InvokeAIDiffuserComponent:
        # Unconditioned pass
        #####################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

        # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

        # Prepare cross-attention control kwargs for the unconditioned pass.
        if cross_attn_processor_context is not None:
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

        # Prepare SDXL conditioning kwargs for the unconditioned pass.
        added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text.add_time_ids,
            }

+        # Prepare prompt regions for the unconditioned pass.
+        if conditioning_data.uncond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
        # Run unconditioned UNet denoising (i.e. negative prompt).
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,

@ -463,36 +469,41 @@ class InvokeAIDiffuserComponent:
        # Conditioned pass
        ###################

-        cross_attention_kwargs = None
+        cross_attention_kwargs = {}

        # Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
            # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
-            cross_attention_kwargs = {
-                "ip_adapter_image_prompt_embeds": [
-                    torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
-                ]
-            }
+            cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
+                torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
+                for ipa_conditioning in ip_adapter_conditioning
+            ]

        # Prepare cross-attention control kwargs for the conditioned pass.
        if cross_attn_processor_context is not None:
            cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
-            cross_attention_kwargs = {"swap_cross_attn_context": cross_attn_processor_context}
+            cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context

        # Prepare SDXL conditioning kwargs for the conditioned pass.
        added_cond_kwargs = None
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                "time_ids": conditioning_data.cond_text.add_time_ids,
            }

+        # Prepare prompt regions for the conditioned pass.
+        if conditioning_data.cond_regions is not None:
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
+                regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype
+            )
+            cross_attention_kwargs["percent_through"] = percent_through
+
        # Run conditioned UNet denoising (i.e. positive prompt).
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
-            conditioning_data.text_embeddings.embeds,
+            conditioning_data.cond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,

@ -506,64 +517,3 @@ class InvokeAIDiffuserComponent:
        scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
        combined_next_x = unconditioned_next_x + scaled_delta
        return combined_next_x
-
-    def apply_symmetry(
-        self,
-        postprocessing_settings: PostprocessingSettings,
-        latents: torch.Tensor,
-        percent_through: float,
-    ) -> torch.Tensor:
-        # Reset our last percent through if this is our first step.
-        if percent_through == 0.0:
-            self.last_percent_through = 0.0
-
-        if postprocessing_settings is None:
-            return latents
-
-        # Check for out of bounds
-        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
-        if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
-            h_symmetry_time_pct = None
-
-        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
-        if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
-            v_symmetry_time_pct = None
-
-        dev = latents.device.type
-
-        latents.to(device="cpu")
-
-        if (
-            h_symmetry_time_pct is not None
-            and self.last_percent_through < h_symmetry_time_pct
-            and percent_through >= h_symmetry_time_pct
-        ):
-            # Horizontal symmetry occurs on the 3rd dimension of the latent
-            width = latents.shape[3]
-            x_flipped = torch.flip(latents, dims=[3])
-            latents = torch.cat(
-                [
-                    latents[:, :, :, 0 : int(width / 2)],
-                    x_flipped[:, :, :, int(width / 2) : int(width)],
-                ],
-                dim=3,
-            )
-
-        if (
-            v_symmetry_time_pct is not None
-            and self.last_percent_through < v_symmetry_time_pct
-            and percent_through >= v_symmetry_time_pct
-        ):
-            # Vertical symmetry occurs on the 2nd dimension of the latent
-            height = latents.shape[2]
-            y_flipped = torch.flip(latents, dims=[2])
-            latents = torch.cat(
-                [
-                    latents[:, :, 0 : int(height / 2)],
-                    y_flipped[:, :, int(height / 2) : int(height)],
-                ],
-                dim=2,
-            )
-
-        self.last_percent_through = percent_through
-        return latents.to(device=dev)
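The final guidance combination retained at the end of this class is plain classifier-free guidance. A minimal numeric sketch (with made-up tensors and a guidance scale of 7.5, not taken from the diff) shows the arithmetic:

import torch

unconditioned_next_x = torch.tensor([0.10, -0.20])
conditioned_next_x = torch.tensor([0.30, 0.05])
guidance_scale = 7.5

scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
# combined_next_x == tensor([1.6000, 1.6750]): the prediction is pushed away from the unconditioned
# result along the direction implied by the prompt, scaled by guidance_scale.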
@ -1,52 +1,55 @@
 from contextlib import contextmanager
+from typing import Optional

 from diffusers.models import UNet2DConditionModel

-from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0


-class UNetPatcher:
-    """A class that contains multiple IP-Adapters and can apply them to a UNet."""
+class UNetAttentionPatcher:
+    """A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""

-    def __init__(self, ip_adapters: list[IPAdapter]):
+    def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
         self._ip_adapters = ip_adapters
-        self._scales = [1.0] * len(self._ip_adapters)
+        self._ip_adapter_scales = None
+
+        if self._ip_adapters is not None:
+            self._ip_adapter_scales = [1.0] * len(self._ip_adapters)

     def set_scale(self, idx: int, value: float):
-        self._scales[idx] = value
+        self._ip_adapter_scales[idx] = value

     def _prepare_attention_processors(self, unet: UNet2DConditionModel):
         """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
-        weights into them.
+        weights into them (if IP-Adapters are being applied).

         Note that the `unet` param is only used to determine attention block dimensions and naming.
         """
         # Construct a dict of attention processors based on the UNet's architecture.
         attn_procs = {}
         for idx, name in enumerate(unet.attn_processors.keys()):
-            if name.endswith("attn1.processor"):
-                attn_procs[name] = AttnProcessor2_0()
+            if name.endswith("attn1.processor") or self._ip_adapters is None:
+                # "attn1" processors do not use IP-Adapters.
+                attn_procs[name] = CustomAttnProcessor2_0()
             else:
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = IPAttnProcessor2_0(
+                attn_procs[name] = CustomAttnProcessor2_0(
                     [ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
-                    self._scales,
+                    self._ip_adapter_scales,
                 )
         return attn_procs

     @contextmanager
     def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
-        """A context manager that patches `unet` with IP-Adapter attention processors."""
+        """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""

         attn_procs = self._prepare_attention_processors(unet)

         orig_attn_processors = unet.attn_processors

         try:
-            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
-            # passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
-            # of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
+            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
+            # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
+            # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
             unet.set_attn_processor(attn_procs)
             yield None
         finally:
@ -1,8 +1,8 @@
 import pytest
 import torch

-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model


@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample