mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
39 Commits
lstein/doc
...
ryan/regio
Author | SHA1 | Date | |
---|---|---|---|
ffc4ebb14c | |||
5b3adf0740 | |||
a5c94fba43 | |||
3e14bd6c45 | |||
8721926f14 | |||
d87ff3a206 | |||
7d9671014b | |||
4a1acd4db9 | |||
8989a6cdc6 | |||
f44d3da9b1 | |||
1bbd4f751d | |||
bdf3691ad0 | |||
e7f7ae660d | |||
e132afb705 | |||
5f49e7ae26 | |||
53ebca58ff | |||
ee1b3157ce | |||
e7ec13f209 | |||
cad3e5dbd7 | |||
845c4e93ae | |||
54971afe44 | |||
cfba51aed5 | |||
2966c8de2c | |||
b0fcbe552e | |||
d132fb4818 | |||
2d5d370f38 | |||
878bbc3527 | |||
caa690e24d | |||
38248b988f | |||
ba4788007f | |||
ef51005881 | |||
7b0326d7f7 | |||
f590b39f88 | |||
58277c6ada | |||
382fa57f3b | |||
ee3abc171d | |||
bf72cee555 | |||
e866e3b19f | |||
16e574825c |
58
invokeai/app/invocations/conditioning.py
Normal file
58
invokeai/app/invocations/conditioning.py
Normal file
@ -0,0 +1,58 @@
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InvocationContext,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
"add_conditioning_mask",
|
||||
title="Add Conditioning Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class AddConditioningMaskInvocation(BaseInvocation):
|
||||
"""Add a mask to an existing conditioning tensor."""
|
||||
|
||||
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
||||
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
self.conditioning.mask = self.mask
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@invocation(
|
||||
"rectangle_mask",
|
||||
title="Create Rectangle Mask",
|
||||
tags=["conditioning"],
|
||||
category="conditioning",
|
||||
version="1.0.0",
|
||||
)
|
||||
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Create a rectangular mask."""
|
||||
|
||||
height: int = InputField(description="The height of the entire mask.")
|
||||
width: int = InputField(description="The width of the entire mask.")
|
||||
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
|
||||
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
|
||||
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
|
||||
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||
mask[
|
||||
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
|
||||
] = True
|
||||
|
||||
mask_name = context.tensors.save(mask)
|
||||
return MaskOutput(
|
||||
mask=MaskField(mask_name=mask_name),
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
@ -194,6 +194,12 @@ class BoardField(BaseModel):
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
class MaskField(BaseModel):
|
||||
"""A mask primitive field."""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask.")
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
@ -225,7 +231,11 @@ class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
# endregion
|
||||
mask: Optional[MaskField] = Field(
|
||||
default=None,
|
||||
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import inspect
|
||||
import math
|
||||
from contextlib import ExitStack
|
||||
from functools import singledispatchmethod
|
||||
@ -9,6 +9,7 @@ import einops
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as T
|
||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
@ -50,12 +51,20 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
IPAdapterConditioningInfo,
|
||||
Range,
|
||||
SDXLConditioningInfo,
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
@ -65,7 +74,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
T2IAdapterData,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from .baseinvocation import (
|
||||
@ -270,11 +278,11 @@ def get_scheduler(
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
@ -351,41 +359,185 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def _get_text_embeddings_and_masks(
|
||||
self,
|
||||
cond_field: Union[ConditioningField, list[ConditioningField]],
|
||||
context: InvocationContext,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
|
||||
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||
# Normalize cond_field to a list.
|
||||
cond_list = cond_field
|
||||
if not isinstance(cond_list, list):
|
||||
cond_list = [cond_list]
|
||||
|
||||
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||
for cond in cond_list:
|
||||
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.mask_name)
|
||||
text_embeddings_masks.append(mask)
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
||||
"""
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
|
||||
return mask
|
||||
|
||||
def concat_regional_text_embeddings(
|
||||
self,
|
||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||
masks: Optional[list[Optional[torch.Tensor]]],
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
||||
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||
if masks is None:
|
||||
masks = [None] * len(text_conditionings)
|
||||
assert len(text_conditionings) == len(masks)
|
||||
|
||||
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||
|
||||
all_masks_are_none = all(mask is None for mask in masks)
|
||||
|
||||
text_embedding = []
|
||||
pooled_embedding = None
|
||||
add_time_ids = None
|
||||
cur_text_embedding_len = 0
|
||||
processed_masks = []
|
||||
embedding_ranges = []
|
||||
extra_conditioning = None
|
||||
|
||||
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
|
||||
if (
|
||||
text_embedding_info.extra_conditioning is not None
|
||||
and text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||
):
|
||||
extra_conditioning = text_embedding_info.extra_conditioning
|
||||
|
||||
if is_sdxl:
|
||||
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
|
||||
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
|
||||
# global prompt information. In an ideal case, there should be exactly one global prompt without a
|
||||
# mask, but we don't enforce this.
|
||||
|
||||
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
|
||||
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
|
||||
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||
if pooled_embedding is None or mask is None:
|
||||
pooled_embedding = text_embedding_info.pooled_embeds
|
||||
if add_time_ids is None or mask is None:
|
||||
add_time_ids = text_embedding_info.add_time_ids
|
||||
|
||||
text_embedding.append(text_embedding_info.embeds)
|
||||
if not all_masks_are_none:
|
||||
embedding_ranges.append(
|
||||
Range(
|
||||
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||
)
|
||||
)
|
||||
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
|
||||
|
||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||
|
||||
text_embedding = torch.cat(text_embedding, dim=1)
|
||||
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||
|
||||
regions = None
|
||||
if not all_masks_are_none:
|
||||
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
|
||||
|
||||
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
|
||||
"prompts."
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
return SDXLConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
extra_conditioning=extra_conditioning,
|
||||
pooled_embeds=pooled_embedding,
|
||||
add_time_ids=add_time_ids,
|
||||
), regions
|
||||
return BasicConditioningInfo(
|
||||
embeds=text_embedding,
|
||||
extra_conditioning=extra_conditioning,
|
||||
), regions
|
||||
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
scheduler: Scheduler,
|
||||
unet: UNet2DConditionModel,
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
unconditioned_embeddings=uc,
|
||||
text_embeddings=c,
|
||||
unet,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
) -> TextConditioningData:
|
||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
self.positive_conditioning, context, unet.device, unet.dtype
|
||||
)
|
||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||
self.negative_conditioning, context, unet.device, unet.dtype
|
||||
)
|
||||
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
|
||||
text_conditionings=cond_text_embeddings,
|
||||
masks=cond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
)
|
||||
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
|
||||
text_conditionings=uncond_text_embeddings,
|
||||
masks=uncond_text_embedding_masks,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
)
|
||||
conditioning_data = TextConditioningData(
|
||||
uncond_text=uncond_text_embedding,
|
||||
cond_text=cond_text_embedding,
|
||||
uncond_regions=uncond_regions,
|
||||
cond_regions=cond_regions,
|
||||
guidance_scale=self.cfg_scale,
|
||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
extra=extra_conditioning_info,
|
||||
postprocessing_settings=PostprocessingSettings(
|
||||
threshold=0.0, # threshold,
|
||||
warmup=0.2, # warmup,
|
||||
h_symmetry_time_pct=None, # h_symmetry_time_pct,
|
||||
v_symmetry_time_pct=None, # v_symmetry_time_pct,
|
||||
),
|
||||
)
|
||||
|
||||
conditioning_data = conditioning_data.add_scheduler_args_if_applicable( # FIXME
|
||||
scheduler,
|
||||
# for ddim scheduler
|
||||
eta=0.0, # ddim_eta
|
||||
# for ancestral and sde schedulers
|
||||
# flip all bits to have noise different from initial
|
||||
generator=torch.Generator(device=unet.device).manual_seed(seed ^ 0xFFFFFFFF),
|
||||
)
|
||||
return conditioning_data
|
||||
|
||||
@ -491,7 +643,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
conditioning_data: ConditioningData,
|
||||
exit_stack: ExitStack,
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
@ -508,7 +659,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return None
|
||||
|
||||
ip_adapter_data_list = []
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
@ -531,16 +681,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
|
||||
conditioning_data.ip_adapter_conditioning.append(
|
||||
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
|
||||
)
|
||||
|
||||
ip_adapter_data_list.append(
|
||||
IPAdapterData(
|
||||
ip_adapter_model=ip_adapter_model,
|
||||
weight=single_ip_adapter.weight,
|
||||
begin_step_percent=single_ip_adapter.begin_step_percent,
|
||||
end_step_percent=single_ip_adapter.end_step_percent,
|
||||
ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
|
||||
)
|
||||
)
|
||||
|
||||
@ -630,6 +777,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
steps: int,
|
||||
denoising_start: float,
|
||||
denoising_end: float,
|
||||
seed: int,
|
||||
) -> Tuple[int, List[int], int]:
|
||||
assert isinstance(scheduler, ConfigMixin)
|
||||
if scheduler.config.get("cpu_only", False):
|
||||
@ -658,7 +806,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
|
||||
num_inference_steps = len(timesteps) // scheduler.order
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
scheduler_step_kwargs = {}
|
||||
scheduler_step_signature = inspect.signature(scheduler.step)
|
||||
if "generator" in scheduler_step_signature.parameters:
|
||||
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||
# reproducibility.
|
||||
scheduler_step_kwargs = {"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)}
|
||||
|
||||
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
@ -751,7 +907,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
conditioning_data = self.get_conditioning_data(
|
||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||
)
|
||||
|
||||
controlnet_data = self.prep_control_data(
|
||||
context=context,
|
||||
@ -765,22 +924,19 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
ip_adapter_data = self.prep_ip_adapter_data(
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
conditioning_data=conditioning_data,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
(
|
||||
result_latents,
|
||||
result_attention_map_saver,
|
||||
) = pipeline.latents_from_embeddings(
|
||||
result_latents = pipeline.latents_from_embeddings(
|
||||
latents=latents,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
@ -790,6 +946,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
|
@ -14,6 +14,7 @@ from invokeai.app.invocations.fields import (
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
MaskField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
@ -229,6 +230,18 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
# region Image
|
||||
|
||||
|
||||
@invocation_output("mask_output")
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""A torch mask tensor.
|
||||
dtype: torch.bool
|
||||
shape: (1, height, width).
|
||||
"""
|
||||
|
||||
mask: MaskField = OutputField(description="The mask.")
|
||||
width: int = OutputField(description="The width of the mask in pixels.")
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("image_output")
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
@ -1,182 +0,0 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
# tencent-ailab comment:
|
||||
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
|
||||
|
||||
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
|
||||
# loading.
|
||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor2_0.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor2_0.__call__(
|
||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||
)
|
||||
|
||||
|
||||
class IPAttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapater for PyTorch 2.0.
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
assert len(weights) == len(scales)
|
||||
|
||||
self._weights = weights
|
||||
self._scales = scales
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Apply IP-Adapter attention.
|
||||
|
||||
Args:
|
||||
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
|
||||
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
"""
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
|
||||
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._weights, self._scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
@ -4,13 +4,11 @@ Initialization file for the invokeai.backend.stable_diffusion package
|
||||
|
||||
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
|
||||
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .seamless import set_seamless # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"PipelineIntermediateState",
|
||||
"StableDiffusionGeneratorPipeline",
|
||||
"InvokeAIDiffuserComponent",
|
||||
"AttentionMapSaver",
|
||||
"set_seamless",
|
||||
]
|
||||
|
@ -12,7 +12,6 @@ import torch
|
||||
import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
@ -24,11 +23,14 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
IPAdapterConditioningInfo,
|
||||
TextConditioningData,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -39,7 +41,6 @@ class PipelineIntermediateState:
|
||||
timestep: int
|
||||
latents: torch.Tensor
|
||||
predicted_original: Optional[torch.Tensor] = None
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -172,10 +173,11 @@ class ControlNetData:
|
||||
|
||||
@dataclass
|
||||
class IPAdapterData:
|
||||
ip_adapter_model: IPAdapter = Field(default=None)
|
||||
# TODO: change to polymorphic so can do different weights per step (once implemented...)
|
||||
ip_adapter_model: IPAdapter
|
||||
ip_adapter_conditioning: IPAdapterConditioningInfo
|
||||
|
||||
# Either a single weight applied to all steps, or a list of weights for each step.
|
||||
weight: Union[float, List[float]] = Field(default=1.0)
|
||||
# weight: float = Field(default=1.0)
|
||||
begin_step_percent: float = Field(default=0.0)
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
@ -190,19 +192,6 @@ class T2IAdapterData:
|
||||
end_step_percent: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
Output class for InvokeAI's Stable Diffusion pipeline.
|
||||
|
||||
Args:
|
||||
attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
|
||||
after generation completes. Optional.
|
||||
"""
|
||||
|
||||
attention_map_saver: Optional[AttentionMapSaver]
|
||||
|
||||
|
||||
class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
@ -329,7 +318,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
conditioning_data: ConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
@ -343,9 +333,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
|
||||
) -> torch.Tensor:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents, None
|
||||
return latents
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
@ -385,10 +375,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
try:
|
||||
latents, attention_map_saver = self.generate_latents_from_embeddings(
|
||||
latents = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
@ -402,46 +393,59 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if mask is not None and not gradient_mask:
|
||||
latents = torch.lerp(orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype))
|
||||
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
) -> torch.Tensor:
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
attention_map_saver: Optional[AttentionMapSaver] = None
|
||||
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
ip_adapter_unet_patcher = None
|
||||
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
|
||||
extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
|
||||
use_cross_attention_control = (
|
||||
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||
)
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
if use_cross_attention_control and use_ip_adapter:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (`.swap()`) and IP-Adapter cannot be used simultaneously."
|
||||
)
|
||||
if use_cross_attention_control and use_regional_prompting:
|
||||
raise ValueError(
|
||||
"Prompt-to-prompt cross-attention control (`.swap()`) and regional prompting cannot be used simultaneously."
|
||||
)
|
||||
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
if use_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=conditioning_data.extra,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
)
|
||||
self.use_ip_adapter = False
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
self.use_ip_adapter = True
|
||||
else:
|
||||
attn_ctx = nullcontext()
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
ip_adapters = [ipa.ip_adapter_model for ipa in ip_adapter_data] if use_ip_adapter else None
|
||||
unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
@ -464,31 +468,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
|
||||
unet_attention_patcher=unet_attention_patcher,
|
||||
)
|
||||
latents = step_output.prev_sample
|
||||
|
||||
latents = self.invokeai_diffuser.do_latent_postprocessing(
|
||||
postprocessing_settings=conditioning_data.postprocessing_settings,
|
||||
latents=latents,
|
||||
sigma=batched_t,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
)
|
||||
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
# TODO resuscitate attention map saving
|
||||
# if i == len(timesteps)-1 and extra_conditioning_info is not None:
|
||||
# eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
|
||||
# attention_map_token_ids = range(1, eos_token_index)
|
||||
# attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
|
||||
# self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)
|
||||
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
@ -498,25 +487,25 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
attention_map_saver=attention_map_saver,
|
||||
)
|
||||
)
|
||||
|
||||
return latents, attention_map_saver
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(
|
||||
self,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
conditioning_data: ConditioningData,
|
||||
conditioning_data: TextConditioningData,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
|
||||
unet_attention_patcher: Optional[UNetAttentionPatcher] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
@ -539,20 +528,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
if step_index >= first_adapter_step and step_index <= last_adapter_step:
|
||||
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
|
||||
ip_adapter_unet_patcher.set_scale(i, weight)
|
||||
unet_attention_patcher.set_scale(i, weight)
|
||||
else:
|
||||
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
|
||||
ip_adapter_unet_patcher.set_scale(i, 0.0)
|
||||
unet_attention_patcher.set_scale(i, 0.0)
|
||||
|
||||
# Handle ControlNet(s) and T2I-Adapter(s)
|
||||
# Handle ControlNet(s)
|
||||
down_block_additional_residuals = None
|
||||
mid_block_additional_residual = None
|
||||
down_intrablock_additional_residuals = None
|
||||
# if control_data is not None and t2i_adapter_data is not None:
|
||||
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
|
||||
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
|
||||
# raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
|
||||
# elif control_data is not None:
|
||||
if control_data is not None:
|
||||
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
|
||||
control_data=control_data,
|
||||
@ -562,7 +545,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
)
|
||||
# elif t2i_adapter_data is not None:
|
||||
|
||||
# Handle T2I-Adapter(s)
|
||||
down_intrablock_additional_residuals = None
|
||||
if t2i_adapter_data is not None:
|
||||
accum_adapter_state = None
|
||||
for single_t2i_adapter_data in t2i_adapter_data:
|
||||
@ -588,16 +573,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
|
||||
accum_adapter_state[idx] += value * t2i_adapter_weight
|
||||
|
||||
# down_block_additional_residuals = accum_adapter_state
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
ip_adapter_conditioning = None
|
||||
if ip_adapter_data is not None:
|
||||
ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data]
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
conditioning_data=conditioning_data,
|
||||
# extra:
|
||||
ip_adapter_conditioning=ip_adapter_conditioning,
|
||||
down_block_additional_residuals=down_block_additional_residuals, # for ControlNet
|
||||
mid_block_additional_residual=mid_block_additional_residual, # for ControlNet
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter
|
||||
@ -617,7 +605,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
|
||||
# TODO: issue to diffusers?
|
||||
# undo internal counter increment done by scheduler.step, so timestep can be resolved as before call
|
||||
|
@ -2,6 +2,4 @@
|
||||
Initialization file for invokeai.models.diffusion
|
||||
"""
|
||||
|
||||
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
|
@ -1,7 +1,5 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -10,6 +8,11 @@ from .cross_attention_control import Arguments
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
"""Extra conditioning information produced by Compel.
|
||||
|
||||
This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
|
||||
"""
|
||||
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@ -20,12 +23,10 @@ class ExtraConditioningInfo:
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
"""SD 1/2 text conditioning information produced by Compel."""
|
||||
|
||||
embeds: torch.Tensor
|
||||
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
|
||||
# should only be stored in one place.
|
||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
@ -39,6 +40,8 @@ class ConditioningFieldData:
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
@ -48,14 +51,6 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
@ -69,42 +64,48 @@ class IPAdapterConditioningInfo:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
guidance_scale: Union[float, List[float]]
|
||||
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
|
||||
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
|
||||
"""
|
||||
guidance_rescale_multiplier: float = 0
|
||||
extra: Optional[ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
class Range:
|
||||
start: int
|
||||
end: int
|
||||
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
class TextConditioningRegions:
|
||||
def __init__(self, masks: torch.Tensor, ranges: list[Range]):
|
||||
# A binary mask indicating the regions of the image that the prompt should be applied to.
|
||||
# Shape: (1, num_prompts, height, width)
|
||||
# Dtype: torch.bool
|
||||
self.masks = masks
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
# A list of ranges indicating the start and end indices of the embeddings that corresponding mask applies to.
|
||||
# ranges[i] contains the embedding range for the i'th prompt / mask.
|
||||
self.ranges = ranges
|
||||
|
||||
assert self.masks.shape[1] == len(self.ranges)
|
||||
|
||||
|
||||
class TextConditioningData:
|
||||
def __init__(
|
||||
self,
|
||||
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
uncond_regions: Optional[TextConditioningRegions],
|
||||
cond_regions: Optional[TextConditioningRegions],
|
||||
guidance_scale: Union[float, List[float]],
|
||||
guidance_rescale_multiplier: float = 0,
|
||||
):
|
||||
self.uncond_text = uncond_text
|
||||
self.cond_text = cond_text
|
||||
self.uncond_regions = uncond_regions
|
||||
self.cond_regions = cond_regions
|
||||
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
self.guidance_scale = guidance_scale
|
||||
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
|
||||
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||
|
||||
def is_sdxl(self):
|
||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
||||
|
@ -3,19 +3,13 @@
|
||||
|
||||
|
||||
import enum
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, SlicedAttnProcessor
|
||||
from diffusers.models.attention_processor import Attention, SlicedAttnProcessor
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...util import torch_dtype
|
||||
|
||||
@ -25,72 +19,14 @@ class CrossAttentionType(enum.Enum):
|
||||
TOKENS = 2
|
||||
|
||||
|
||||
class Context:
|
||||
cross_attention_mask: Optional[torch.Tensor]
|
||||
cross_attention_index_map: Optional[torch.Tensor]
|
||||
|
||||
class Action(enum.Enum):
|
||||
NONE = 0
|
||||
SAVE = (1,)
|
||||
APPLY = 2
|
||||
|
||||
def __init__(self, arguments: Arguments, step_count: int):
|
||||
class CrossAttnControlContext:
|
||||
def __init__(self, arguments: Arguments):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.cross_attention_mask = None
|
||||
self.cross_attention_index_map = None
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.cross_attention_mask: Optional[torch.Tensor] = None
|
||||
self.cross_attention_index_map: Optional[torch.Tensor] = None
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
self.self_cross_attention_module_identifiers = []
|
||||
self.tokens_cross_attention_module_identifiers = []
|
||||
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
self.clear_requests(cleanup=True)
|
||||
|
||||
def register_cross_attention_modules(self, model):
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.SELF):
|
||||
if name in self.self_cross_attention_module_identifiers:
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.self_cross_attention_module_identifiers.append(name)
|
||||
for name, _module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
|
||||
if name in self.tokens_cross_attention_module_identifiers:
|
||||
raise AssertionError(f"name {name} cannot appear more than once")
|
||||
self.tokens_cross_attention_module_identifiers.append(name)
|
||||
|
||||
def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.SAVE
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.SAVE
|
||||
|
||||
def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
|
||||
if cross_attention_type == CrossAttentionType.SELF:
|
||||
self.self_cross_attention_action = Context.Action.APPLY
|
||||
else:
|
||||
self.tokens_cross_attention_action = Context.Action.APPLY
|
||||
|
||||
def is_tokens_cross_attention(self, module_identifier) -> bool:
|
||||
return module_identifier in self.tokens_cross_attention_module_identifiers
|
||||
|
||||
def get_should_save_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.SAVE
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.SAVE
|
||||
return False
|
||||
|
||||
def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
|
||||
if module_identifier in self.self_cross_attention_module_identifiers:
|
||||
return self.self_cross_attention_action == Context.Action.APPLY
|
||||
elif module_identifier in self.tokens_cross_attention_module_identifiers:
|
||||
return self.tokens_cross_attention_action == Context.Action.APPLY
|
||||
return False
|
||||
|
||||
def get_active_cross_attention_control_types_for_step(
|
||||
self, percent_through: float = None
|
||||
@ -111,219 +47,8 @@ class Context:
|
||||
to_control.append(CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
def save_slice(
|
||||
self,
|
||||
identifier: str,
|
||||
slice: torch.Tensor,
|
||||
dim: Optional[int],
|
||||
offset: int,
|
||||
slice_size: Optional[int],
|
||||
):
|
||||
if identifier not in self.saved_cross_attention_maps:
|
||||
self.saved_cross_attention_maps[identifier] = {
|
||||
"dim": dim,
|
||||
"slice_size": slice_size,
|
||||
"slices": {offset or 0: slice},
|
||||
}
|
||||
else:
|
||||
self.saved_cross_attention_maps[identifier]["slices"][offset or 0] = slice
|
||||
|
||||
def get_slice(
|
||||
self,
|
||||
identifier: str,
|
||||
requested_dim: Optional[int],
|
||||
requested_offset: int,
|
||||
slice_size: int,
|
||||
):
|
||||
saved_attention_dict = self.saved_cross_attention_maps[identifier]
|
||||
if requested_dim is None:
|
||||
if saved_attention_dict["dim"] is not None:
|
||||
raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
|
||||
return saved_attention_dict["slices"][0]
|
||||
|
||||
if saved_attention_dict["dim"] == requested_dim:
|
||||
if slice_size != saved_attention_dict["slice_size"]:
|
||||
raise RuntimeError(
|
||||
f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}"
|
||||
)
|
||||
return saved_attention_dict["slices"][requested_offset]
|
||||
|
||||
if saved_attention_dict["dim"] is None:
|
||||
whole_saved_attention = saved_attention_dict["slices"][0]
|
||||
if requested_dim == 0:
|
||||
return whole_saved_attention[requested_offset : requested_offset + slice_size]
|
||||
elif requested_dim == 1:
|
||||
return whole_saved_attention[:, requested_offset : requested_offset + slice_size]
|
||||
|
||||
raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")
|
||||
|
||||
def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
|
||||
saved_attention = self.saved_cross_attention_maps.get(identifier, None)
|
||||
if saved_attention is None:
|
||||
return None, None
|
||||
return saved_attention["dim"], saved_attention["slice_size"]
|
||||
|
||||
def clear_requests(self, cleanup=True):
|
||||
self.tokens_cross_attention_action = Context.Action.NONE
|
||||
self.self_cross_attention_action = Context.Action.NONE
|
||||
if cleanup:
|
||||
self.saved_cross_attention_maps = {}
|
||||
|
||||
def offload_saved_attention_slices_to_cpu(self):
|
||||
for _key, map_dict in self.saved_cross_attention_maps.items():
|
||||
for offset, slice in map_dict["slices"].items():
|
||||
map_dict[offset] = slice.to("cpu")
|
||||
|
||||
|
||||
class InvokeAICrossAttentionMixin:
|
||||
"""
|
||||
Enable InvokeAI-flavoured Attention calculation, which does aggressive low-memory slicing and calls
|
||||
through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
|
||||
and dymamic slicing strategy selection.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
self.attention_slice_wrangler = None
|
||||
self.slicing_strategy_getter = None
|
||||
self.attention_slice_calculated_callback = None
|
||||
|
||||
def set_attention_slice_wrangler(
|
||||
self,
|
||||
wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]],
|
||||
):
|
||||
"""
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
`module` is the current Attention module for which the callback is being invoked.
|
||||
`suggested_attention_slice` is the default-calculated attention slice
|
||||
`dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
"""
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int, int]]]):
|
||||
self.slicing_strategy_getter = getter
|
||||
|
||||
def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor], None]]):
|
||||
self.attention_slice_calculated_callback = callback
|
||||
|
||||
def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
# attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(
|
||||
query.shape[0],
|
||||
query.shape[1],
|
||||
key.shape[1],
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
),
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
# calculate attention slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
attention_slice_wrangler = self.attention_slice_wrangler
|
||||
if attention_slice_wrangler is not None:
|
||||
attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
if self.attention_slice_calculated_callback is not None:
|
||||
self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
|
||||
|
||||
hidden_states = torch.bmm(attention_slice, value)
|
||||
return hidden_states
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
# check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
|
||||
slicing_strategy_getter = self.slicing_strategy_getter
|
||||
if slicing_strategy_getter is not None:
|
||||
(dim, slice_size) = slicing_strategy_getter(self)
|
||||
if dim is not None:
|
||||
# print("using saved slicing strategy with dim", dim, "slice size", slice_size)
|
||||
if dim == 0:
|
||||
return self.einsum_op_slice_dim0(q, k, v, slice_size)
|
||||
elif dim == 1:
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
# fallback for when there is no saved strategy, or saved strategy does not slice
|
||||
mem_free_total = get_mem_free_total(q.device)
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
def get_invokeai_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == "cuda":
|
||||
# print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
if q.device.type == "mps" or q.device.type == "cpu":
|
||||
if self.mem_total_gb >= 32:
|
||||
return self.einsum_op_mps_v1(q, k, v)
|
||||
return self.einsum_op_mps_v2(q, k, v)
|
||||
|
||||
# Smaller slices are faster due to L2/L3/SLC caches.
|
||||
# Tested on i7 with 8MB L3 cache.
|
||||
return self.einsum_op_tensor_mem(q, k, v, 32)
|
||||
|
||||
|
||||
def restore_default_cross_attention(
|
||||
model,
|
||||
is_running_diffusers: bool,
|
||||
restore_attention_processor: Optional[AttentionProcessor] = None,
|
||||
):
|
||||
if is_running_diffusers:
|
||||
unet = model
|
||||
unet.set_attn_processor(restore_attention_processor or AttnProcessor())
|
||||
else:
|
||||
remove_attention_function(model)
|
||||
|
||||
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: Context):
|
||||
def setup_cross_attention_control_attention_processors(unet: UNet2DConditionModel, context: CrossAttnControlContext):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
@ -362,170 +87,6 @@ def setup_cross_attention_control_attention_processors(unet: UNet2DConditionMode
|
||||
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
|
||||
|
||||
|
||||
def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
|
||||
cross_attention_class: type = InvokeAIDiffusersCrossAttention
|
||||
which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
|
||||
attention_module_tuples = [
|
||||
(name, module)
|
||||
for name, module in model.named_modules()
|
||||
if isinstance(module, cross_attention_class) and which_attn in name
|
||||
]
|
||||
cross_attention_modules_in_model_count = len(attention_module_tuples)
|
||||
expected_count = 16
|
||||
if cross_attention_modules_in_model_count != expected_count:
|
||||
# non-fatal error but .swap() won't work.
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
|
||||
"failed or some assumption has changed about the structure of the model itself. Please fix the "
|
||||
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
|
||||
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
|
||||
"attention map display will not work properly until it is fixed."
|
||||
)
|
||||
return attention_module_tuples
|
||||
|
||||
|
||||
def inject_attention_function(unet, context: Context):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
|
||||
# memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
|
||||
|
||||
attention_slice = suggested_attention_slice
|
||||
|
||||
if context.get_should_save_maps(module.identifier):
|
||||
# print(module.identifier, "saving suggested_attention_slice of shape",
|
||||
# suggested_attention_slice.shape, "dim", dim, "offset", offset)
|
||||
slice_to_save = attention_slice.to("cpu") if dim is not None else attention_slice
|
||||
context.save_slice(
|
||||
module.identifier,
|
||||
slice_to_save,
|
||||
dim=dim,
|
||||
offset=offset,
|
||||
slice_size=slice_size,
|
||||
)
|
||||
elif context.get_should_apply_saved_maps(module.identifier):
|
||||
# print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
|
||||
saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
|
||||
|
||||
# slice may have been offloaded to CPU
|
||||
saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
|
||||
|
||||
if context.is_tokens_cross_attention(module.identifier):
|
||||
index_map = context.cross_attention_index_map
|
||||
remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
|
||||
this_attention_slice = suggested_attention_slice
|
||||
|
||||
mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
|
||||
saved_mask = mask
|
||||
this_mask = 1 - mask
|
||||
attention_slice = remapped_saved_attention_slice * saved_mask + this_attention_slice * this_mask
|
||||
else:
|
||||
# just use everything
|
||||
attention_slice = saved_attention_slice
|
||||
|
||||
return attention_slice
|
||||
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for identifier, module in cross_attention_modules:
|
||||
module.identifier = identifier
|
||||
try:
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
module.set_slicing_strategy_getter(lambda module: context.get_slicing_strategy(identifier)) # noqa: B023
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def remove_attention_function(unet):
|
||||
cross_attention_modules = get_cross_attention_modules(
|
||||
unet, CrossAttentionType.TOKENS
|
||||
) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
|
||||
for _identifier, module in cross_attention_modules:
|
||||
try:
|
||||
# clear wrangler callback
|
||||
module.set_attention_slice_wrangler(None)
|
||||
module.set_slicing_strategy_getter(None)
|
||||
except AttributeError as e:
|
||||
if is_attribute_error_about(e, "set_attention_slice_wrangler"):
|
||||
print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def is_attribute_error_about(error: AttributeError, attribute: str):
|
||||
if hasattr(error, "name"): # Python 3.10
|
||||
return error.name == attribute
|
||||
else: # Python 3.9
|
||||
return attribute in str(error)
|
||||
|
||||
|
||||
def get_mem_free_total(device):
|
||||
# only on cuda
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
stats = torch.cuda.memory_stats(device)
|
||||
mem_active = stats["active_bytes.all.current"]
|
||||
mem_reserved = stats["reserved_bytes.all.current"]
|
||||
mem_free_cuda, _ = torch.cuda.mem_get_info(device)
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
return mem_free_total
|
||||
|
||||
|
||||
class InvokeAIDiffusersCrossAttention(diffusers.models.attention.Attention, InvokeAICrossAttentionMixin):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
InvokeAICrossAttentionMixin.__init__(self)
|
||||
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
# default_result = super()._attention(query, key, value)
|
||||
if attention_mask is not None:
|
||||
print(f"{type(self).__name__} ignoring passed-in attention_mask")
|
||||
attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(attention_result)
|
||||
return hidden_states
|
||||
|
||||
|
||||
## 🧨diffusers implementation follows
|
||||
|
||||
|
||||
"""
|
||||
# base implementation
|
||||
|
||||
class AttnProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.head_to_batch_dim(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SwapCrossAttnContext:
|
||||
modified_text_embeddings: torch.Tensor
|
||||
@ -533,18 +94,6 @@ class SwapCrossAttnContext:
|
||||
mask: torch.Tensor # in the target space of the index_map
|
||||
cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)
|
||||
|
||||
def __int__(
|
||||
self,
|
||||
cac_types_to_do: [CrossAttentionType],
|
||||
modified_text_embeddings: torch.Tensor,
|
||||
index_map: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
self.cross_attention_types_to_do = cac_types_to_do
|
||||
self.modified_text_embeddings = modified_text_embeddings
|
||||
self.index_map = index_map
|
||||
self.mask = mask
|
||||
|
||||
def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
|
||||
return attn_type in self.cross_attention_types_to_do
|
||||
|
||||
|
@ -1,100 +0,0 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
|
||||
class AttentionMapSaver:
|
||||
def __init__(self, token_ids: range, latents_shape: torch.Size):
|
||||
self.token_ids = token_ids
|
||||
self.latents_shape = latents_shape
|
||||
# self.collated_maps = #torch.zeros([len(token_ids), latents_shape[0], latents_shape[1]])
|
||||
self.collated_maps: dict[str, torch.Tensor] = {}
|
||||
|
||||
def clear_maps(self):
|
||||
self.collated_maps = {}
|
||||
|
||||
def add_attention_maps(self, maps: torch.Tensor, key: str):
|
||||
"""
|
||||
Accumulate the given attention maps and store by summing with existing maps at the passed-in key (if any).
|
||||
:param maps: Attention maps to store. Expected shape [A, (H*W), N] where A is attention heads count, H and W are the map size (fixed per-key) and N is the number of tokens (typically 77).
|
||||
:param key: Storage key. If a map already exists for this key it will be summed with the incoming data. In this case the maps sizes (H and W) should match.
|
||||
:return: None
|
||||
"""
|
||||
key_and_size = f"{key}_{maps.shape[1]}"
|
||||
|
||||
# extract desired tokens
|
||||
maps = maps[:, :, self.token_ids]
|
||||
|
||||
# merge attention heads to a single map per token
|
||||
maps = torch.sum(maps, 0)
|
||||
|
||||
# store
|
||||
if key_and_size not in self.collated_maps:
|
||||
self.collated_maps[key_and_size] = torch.zeros_like(maps, device="cpu")
|
||||
self.collated_maps[key_and_size] += maps.cpu()
|
||||
|
||||
def write_maps_to_disk(self, path: str):
|
||||
pil_image = self.get_stacked_maps_image()
|
||||
if pil_image is not None:
|
||||
pil_image.save(path, "PNG")
|
||||
|
||||
def get_stacked_maps_image(self) -> Optional[Image.Image]:
|
||||
"""
|
||||
Scale all collected attention maps to the same size, blend them together and return as an image.
|
||||
:return: An image containing a vertical stack of blended attention maps, one for each requested token.
|
||||
"""
|
||||
num_tokens = len(self.token_ids)
|
||||
if num_tokens == 0:
|
||||
return None
|
||||
|
||||
latents_height = self.latents_shape[0]
|
||||
latents_width = self.latents_shape[1]
|
||||
|
||||
merged = None
|
||||
|
||||
for _key, maps in self.collated_maps.items():
|
||||
# maps has shape [(H*W), N] for N tokens
|
||||
# but we want [N, H, W]
|
||||
this_scale_factor = math.sqrt(maps.shape[0] / (latents_width * latents_height))
|
||||
this_maps_height = int(float(latents_height) * this_scale_factor)
|
||||
this_maps_width = int(float(latents_width) * this_scale_factor)
|
||||
# and we need to do some dimension juggling
|
||||
maps = torch.reshape(
|
||||
torch.swapdims(maps, 0, 1),
|
||||
[num_tokens, this_maps_height, this_maps_width],
|
||||
)
|
||||
|
||||
# scale to output size if necessary
|
||||
if this_scale_factor != 1:
|
||||
maps = tv_resize(maps, [latents_height, latents_width], InterpolationMode.BICUBIC)
|
||||
|
||||
# normalize
|
||||
maps_min = torch.min(maps)
|
||||
maps_range = torch.max(maps) - maps_min
|
||||
# print(f"map {key} size {[this_maps_width, this_maps_height]} range {[maps_min, maps_min + maps_range]}")
|
||||
maps_normalized = (maps - maps_min) / maps_range
|
||||
# expand to (-0.1, 1.1) and clamp
|
||||
maps_normalized_expanded = maps_normalized * 1.1 - 0.05
|
||||
maps_normalized_expanded_clamped = torch.clamp(maps_normalized_expanded, 0, 1)
|
||||
|
||||
# merge together, producing a vertical stack
|
||||
maps_stacked = torch.reshape(
|
||||
maps_normalized_expanded_clamped,
|
||||
[num_tokens * latents_height, latents_width],
|
||||
)
|
||||
|
||||
if merged is None:
|
||||
merged = maps_stacked
|
||||
else:
|
||||
# screen blend
|
||||
merged = 1 - (1 - maps_stacked) * (1 - merged)
|
||||
|
||||
if merged is None:
|
||||
return None
|
||||
|
||||
merged_bytes = merged.mul(0xFF).byte()
|
||||
return Image.fromarray(merged_bytes.numpy(), mode="L")
|
209
invokeai/backend/stable_diffusion/diffusion/custom_attention.py
Normal file
209
invokeai/backend/stable_diffusion/diffusion/custom_attention.py
Normal file
@ -0,0 +1,209 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
|
||||
class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
"""A custom implementation of AttnProcessor2_0 that supports additional Invoke features.
|
||||
|
||||
This implementation is based on
|
||||
https://github.com/huggingface/diffusers/blame/fcfa270fbd1dc294e2f3a505bae6bcb791d721c3/src/diffusers/models/attention_processor.py#L1204
|
||||
|
||||
Supported custom features:
|
||||
- IP-Adapter
|
||||
- Regional prompt attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ip_adapter_weights: Optional[list[IPAttentionProcessorWeights]] = None,
|
||||
ip_adapter_scales: Optional[list[float]] = None,
|
||||
):
|
||||
"""Initialize a CustomAttnProcessor2_0.
|
||||
|
||||
Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
|
||||
layer-specific are passed to __init__().
|
||||
|
||||
Args:
|
||||
ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights
|
||||
for the i'th IP-Adapter.
|
||||
ip_adapter_scales: The IP-Adapter attention scales. ip_adapter_scales[i] contains the attention scale for
|
||||
the i'th IP-Adapter.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._ip_adapter_weights = ip_adapter_weights
|
||||
self._ip_adapter_scales = ip_adapter_scales
|
||||
|
||||
assert (self._ip_adapter_weights is None) == (self._ip_adapter_scales is None)
|
||||
if self._ip_adapter_weights is not None:
|
||||
assert len(ip_adapter_weights) == len(ip_adapter_scales)
|
||||
|
||||
def _is_ip_adapter_enabled(self) -> bool:
|
||||
return self._ip_adapter_weights is not None
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
# For regional prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
# For IP-Adapter:
|
||||
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
|
||||
Args:
|
||||
regional_prompt_data: The regional prompt data for the current batch. If not None, this will be used to
|
||||
apply regional prompt masking.
|
||||
ip_adapter_image_prompt_embeds: The IP-Adapter image prompt embeddings for the current batch.
|
||||
ip_adapter_image_prompt_embeds[i] contains the image prompt embeddings for the i'th IP-Adapter. Each
|
||||
tensor has shape (batch_size, num_ip_images, seq_len, ip_embedding_len).
|
||||
"""
|
||||
# If true, we are doing cross-attention, if false we are doing self-attention.
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
# Handle regional prompt attention masks.
|
||||
if is_cross_attention and regional_prompt_data is not None:
|
||||
_, query_seq_len, _ = hidden_states.shape
|
||||
prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
|
||||
# TODO(ryand): Avoid redundant type/device conversion here.
|
||||
prompt_region_attention_mask = prompt_region_attention_mask.to(
|
||||
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
|
||||
)
|
||||
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
|
||||
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = prompt_region_attention_mask
|
||||
else:
|
||||
attention_mask = prompt_region_attention_mask + attention_mask
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
# End unmodified block from AttnProcessor2_0.
|
||||
|
||||
# Apply IP-Adapter conditioning.
|
||||
if is_cross_attention and self._is_ip_adapter_enabled():
|
||||
if self._is_ip_adapter_enabled():
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
for ipa_embed, ipa_weights, scale in zip(
|
||||
ip_adapter_image_prompt_embeds, self._ip_adapter_weights, self._ip_adapter_scales, strict=True
|
||||
):
|
||||
# The batch dimensions should match.
|
||||
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The token_len dimensions should match.
|
||||
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then ip_adapter_image_prompt_embeds should not be passed in.
|
||||
assert ip_adapter_image_prompt_embeds is None
|
||||
|
||||
# Start unmodified block from AttnProcessor2_0.
|
||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
TextConditioningRegions,
|
||||
)
|
||||
|
||||
|
||||
class RegionalPromptData:
|
||||
def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
|
||||
self._attn_masks_by_seq_len = attn_masks_by_seq_len
|
||||
|
||||
@classmethod
|
||||
def from_regions(
|
||||
cls,
|
||||
regions: list[TextConditioningRegions],
|
||||
key_seq_len: int,
|
||||
# TODO(ryand): Pass in a list of downscale factors?
|
||||
max_downscale_factor: int = 8,
|
||||
):
|
||||
"""Construct a `RegionalPromptData` object.
|
||||
|
||||
Args:
|
||||
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
||||
batch.
|
||||
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
|
||||
cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
|
||||
explicitly to be sure.
|
||||
"""
|
||||
attn_masks_by_seq_len = {}
|
||||
|
||||
# batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
|
||||
# length of s.
|
||||
batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
|
||||
for batch_sample_regions in regions:
|
||||
batch_attn_masks_by_seq_len.append({})
|
||||
|
||||
# Convert the bool masks to float masks so that max pooling can be applied.
|
||||
batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
|
||||
|
||||
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
|
||||
downscale_factor = 1
|
||||
while downscale_factor <= max_downscale_factor:
|
||||
_, num_prompts, h, w = batch_masks.shape
|
||||
query_seq_len = h * w
|
||||
|
||||
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
|
||||
batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
|
||||
|
||||
# Create a cross-attention mask for each prompt that selects the corresponding embeddings from
|
||||
# `encoder_hidden_states`.
|
||||
# attn_mask shape: (batch_size, query_seq_len, key_seq_len)
|
||||
# TODO(ryand): What device / dtype should this be?
|
||||
attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
|
||||
|
||||
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
|
||||
attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
|
||||
:, prompt_idx, :, :
|
||||
]
|
||||
|
||||
batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
|
||||
|
||||
downscale_factor *= 2
|
||||
if downscale_factor <= max_downscale_factor:
|
||||
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
|
||||
# regions to be lost entirely.
|
||||
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
|
||||
# potentially use a weighted mask rather than a binary mask.
|
||||
batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
|
||||
|
||||
# Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
|
||||
for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
|
||||
attn_masks_by_seq_len[query_seq_len] = torch.cat(
|
||||
[batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
|
||||
)
|
||||
|
||||
return cls(attn_masks_by_seq_len)
|
||||
|
||||
def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
|
||||
"""Get the attention mask for the given query sequence length (i.e. downscaling level).
|
||||
|
||||
This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
|
||||
it changes at each downscaling level in the model.
|
||||
|
||||
key_seq_len is the length of the expected prompt embeddings.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The masks.
|
||||
shape: (batch_size, query_seq_len, key_seq_len).
|
||||
dtype: float
|
||||
The mask is a binary mask with values of 0.0 and 1.0.
|
||||
"""
|
||||
return self._attn_masks_by_seq_len[query_seq_len]
|
@ -10,20 +10,20 @@ from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
ExtraConditioningInfo,
|
||||
PostprocessingSettings,
|
||||
SDXLConditioningInfo,
|
||||
IPAdapterConditioningInfo,
|
||||
Range,
|
||||
TextConditioningData,
|
||||
TextConditioningRegions,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||
|
||||
from .cross_attention_control import (
|
||||
Context,
|
||||
CrossAttentionType,
|
||||
CrossAttnControlContext,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
ModelForwardCallback: TypeAlias = Union[
|
||||
# x, t, conditioning, Optional[cross-attention kwargs]
|
||||
@ -58,7 +58,6 @@ class InvokeAIDiffuserComponent:
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
@ -69,14 +68,12 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int,
|
||||
):
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = Context(
|
||||
self.cross_attention_control_context = CrossAttnControlContext(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
@ -87,27 +84,6 @@ class InvokeAIDiffuserComponent:
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
# sliced tokens attention map saving is not implemented
|
||||
return
|
||||
saver.add_attention_maps(slice, key)
|
||||
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for identifier, module in tokens_cross_attention_modules:
|
||||
key = "down" if identifier.startswith("down") else "up" if identifier.startswith("up") else "mid"
|
||||
module.set_attention_slice_calculated_callback(
|
||||
lambda slice, dim, offset, slice_size, key=key: callback(slice, dim, offset, slice_size, key)
|
||||
)
|
||||
|
||||
def remove_attention_map_saving(self):
|
||||
tokens_cross_attention_modules = get_cross_attention_modules(self.model, CrossAttentionType.TOKENS)
|
||||
for _, module in tokens_cross_attention_modules:
|
||||
module.set_attention_slice_calculated_callback(None)
|
||||
|
||||
def do_controlnet_step(
|
||||
self,
|
||||
@ -116,7 +92,7 @@ class InvokeAIDiffuserComponent:
|
||||
timestep: torch.Tensor,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
conditioning_data,
|
||||
conditioning_data: TextConditioningData,
|
||||
):
|
||||
down_block_res_samples, mid_block_res_sample = None, None
|
||||
|
||||
@ -149,38 +125,30 @@ class InvokeAIDiffuserComponent:
|
||||
added_cond_kwargs = None
|
||||
|
||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
}
|
||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||
encoder_hidden_states = conditioning_data.cond_text.embeds
|
||||
encoder_attention_mask = None
|
||||
else:
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
conditioning_data.uncond_text.pooled_embeds,
|
||||
conditioning_data.cond_text.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
],
|
||||
[conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
(
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
(encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
|
||||
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
||||
)
|
||||
if isinstance(control_datum.weight, list):
|
||||
# if controlnet has multiple weights, use the weight for the current step
|
||||
@ -224,68 +192,54 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
conditioning_data, # TODO: type
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
**kwargs,
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
cross_attention_control_types_to_do = (
|
||||
self.cross_attention_control_context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
)
|
||||
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
if wants_cross_attention_control:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_cross_attention_controlled_conditioning(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
elif self.sequential_guidance:
|
||||
if wants_cross_attention_control or self.sequential_guidance:
|
||||
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
||||
# control is currently only supported in sequential mode.
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning_sequentially(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
**kwargs,
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_conditioning=ip_adapter_conditioning,
|
||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)
|
||||
else:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
sample,
|
||||
timestep,
|
||||
conditioning_data,
|
||||
**kwargs,
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
ip_adapter_conditioning=ip_adapter_conditioning,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def do_latent_postprocessing(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
sigma,
|
||||
step_index,
|
||||
total_step_count,
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
|
||||
return latents
|
||||
|
||||
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
|
||||
def _pad_conditioning(cond, target_len, encoder_attention_mask):
|
||||
conditioning_attention_mask = torch.ones(
|
||||
@ -333,58 +287,79 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
|
||||
def _apply_standard_conditioning(
|
||||
self,
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||
the cost of higher memory usage.
|
||||
"""
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {}
|
||||
if ip_adapter_conditioning is not None:
|
||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.stack(
|
||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||
)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
|
||||
torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
|
||||
uncond_text = conditioning_data.uncond_text
|
||||
cond_text = conditioning_data.cond_text
|
||||
|
||||
added_cond_kwargs = None
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
conditioning_data.text_embeddings.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
conditioning_data.text_embeddings.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
|
||||
"time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
|
||||
}
|
||||
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
||||
uncond_text.embeds, cond_text.embeds
|
||||
)
|
||||
|
||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
||||
# TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
|
||||
# masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
||||
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
||||
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
||||
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
||||
regions = []
|
||||
for c, r in [
|
||||
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
|
||||
(conditioning_data.cond_text, conditioning_data.cond_regions),
|
||||
]:
|
||||
if r is None:
|
||||
# Create a dummy mask and range for text conditioning that doesn't have region masks.
|
||||
_, _, h, w = x.shape
|
||||
r = TextConditioningRegions(
|
||||
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
|
||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
||||
)
|
||||
regions.append(r)
|
||||
|
||||
_, key_seq_len, _ = both_conditionings.shape
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
|
||||
regions=regions, key_seq_len=key_seq_len
|
||||
)
|
||||
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice,
|
||||
sigma_twice,
|
||||
both_conditionings,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
unconditioned_next_x, conditioned_next_x = both_results.chunk(2)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
@ -393,15 +368,19 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
conditioning_data: ConditioningData,
|
||||
**kwargs,
|
||||
conditioning_data: TextConditioningData,
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
cross_attention_control_types_to_do: list[CrossAttentionType],
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
||||
slower execution speed.
|
||||
"""
|
||||
# low-memory sequential path
|
||||
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
|
||||
# and T2I-Adapter residuals into two chunks.
|
||||
uncond_down_block, cond_down_block = None, None
|
||||
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
||||
if down_block_additional_residuals is not None:
|
||||
uncond_down_block, cond_down_block = [], []
|
||||
for down_block in down_block_additional_residuals:
|
||||
@ -410,7 +389,6 @@ class InvokeAIDiffuserComponent:
|
||||
cond_down_block.append(_cond_down)
|
||||
|
||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
||||
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
|
||||
if down_intrablock_additional_residuals is not None:
|
||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
||||
for down_intrablock in down_intrablock_additional_residuals:
|
||||
@ -419,151 +397,111 @@ class InvokeAIDiffuserComponent:
|
||||
cond_down_intrablock.append(_cond_down)
|
||||
|
||||
uncond_mid_block, cond_mid_block = None, None
|
||||
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
||||
if mid_block_additional_residual is not None:
|
||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||
|
||||
# Run unconditional UNet denoising.
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
# If cross-attention control is enabled, prepare the SwapCrossAttnContext.
|
||||
cross_attn_processor_context = None
|
||||
if self.cross_attention_control_context is not None:
|
||||
# Note that the SwapCrossAttnContext is initialized with an empty list of cross_attention_types_to_do.
|
||||
# This list is empty because cross-attention control is not applied in the unconditioned pass. This field
|
||||
# will be populated before the conditioned pass.
|
||||
cross_attn_processor_context = SwapCrossAttnContext(
|
||||
modified_text_embeddings=self.cross_attention_control_context.arguments.edited_conditioning,
|
||||
index_map=self.cross_attention_control_context.cross_attention_index_map,
|
||||
mask=self.cross_attention_control_context.cross_attention_mask,
|
||||
cross_attention_types_to_do=[],
|
||||
)
|
||||
|
||||
#####################
|
||||
# Unconditioned pass
|
||||
#####################
|
||||
|
||||
cross_attention_kwargs = {}
|
||||
|
||||
# Prepare IP-Adapter cross-attention kwargs for the unconditioned pass.
|
||||
if ip_adapter_conditioning is not None:
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
|
||||
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
|
||||
# Prepare cross-attention control kwargs for the unconditioned pass.
|
||||
if cross_attn_processor_context is not None:
|
||||
cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||
added_cond_kwargs = None
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.uncond_text.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the unconditioned pass.
|
||||
if conditioning_data.uncond_regions is not None:
|
||||
_, key_seq_len, _ = conditioning_data.uncond_text.embeds.shape
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
|
||||
regions=[conditioning_data.uncond_regions], key_seq_len=key_seq_len
|
||||
)
|
||||
|
||||
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||
unconditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
conditioning_data.uncond_text.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Run conditional UNet denoising.
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
###################
|
||||
# Conditioned pass
|
||||
###################
|
||||
|
||||
cross_attention_kwargs = {}
|
||||
|
||||
# Prepare IP-Adapter cross-attention kwargs for the conditioned pass.
|
||||
if ip_adapter_conditioning is not None:
|
||||
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
cross_attention_kwargs["ip_adapter_image_prompt_embeds"] = [
|
||||
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
|
||||
# Prepare cross-attention control kwargs for the conditioned pass.
|
||||
if cross_attn_processor_context is not None:
|
||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||
cross_attention_kwargs["swap_cross_attn_context"] = cross_attn_processor_context
|
||||
|
||||
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
||||
added_cond_kwargs = None
|
||||
if is_sdxl:
|
||||
if conditioning_data.is_sdxl():
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||
}
|
||||
|
||||
# Prepare prompt regions for the conditioned pass.
|
||||
if conditioning_data.cond_regions is not None:
|
||||
_, key_seq_len, _ = conditioning_data.cond_text.embeds.shape
|
||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
|
||||
regions=[conditioning_data.cond_regions], key_seq_len=key_seq_len
|
||||
)
|
||||
|
||||
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
conditioning_data.cond_text.embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
conditioning_data,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
uncond_down_block, cond_down_block = None, None
|
||||
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
|
||||
if down_block_additional_residuals is not None:
|
||||
uncond_down_block, cond_down_block = [], []
|
||||
for down_block in down_block_additional_residuals:
|
||||
_uncond_down, _cond_down = down_block.chunk(2)
|
||||
uncond_down_block.append(_uncond_down)
|
||||
cond_down_block.append(_cond_down)
|
||||
|
||||
uncond_down_intrablock, cond_down_intrablock = None, None
|
||||
down_intrablock_additional_residuals = kwargs.pop("down_intrablock_additional_residuals", None)
|
||||
if down_intrablock_additional_residuals is not None:
|
||||
uncond_down_intrablock, cond_down_intrablock = [], []
|
||||
for down_intrablock in down_intrablock_additional_residuals:
|
||||
_uncond_down, _cond_down = down_intrablock.chunk(2)
|
||||
uncond_down_intrablock.append(_uncond_down)
|
||||
cond_down_intrablock.append(_cond_down)
|
||||
|
||||
uncond_mid_block, cond_mid_block = None, None
|
||||
mid_block_additional_residual = kwargs.pop("mid_block_additional_residual", None)
|
||||
if mid_block_additional_residual is not None:
|
||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||
|
||||
cross_attn_processor_context = SwapCrossAttnContext(
|
||||
modified_text_embeddings=context.arguments.edited_conditioning,
|
||||
index_map=context.cross_attention_index_map,
|
||||
mask=context.cross_attention_mask,
|
||||
cross_attention_types_to_do=[],
|
||||
)
|
||||
|
||||
added_cond_kwargs = None
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
}
|
||||
|
||||
# no cross attention for unconditioning (negative prompt)
|
||||
unconditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.unconditioned_embeddings.embeds,
|
||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||
down_block_additional_residuals=uncond_down_block,
|
||||
mid_block_additional_residual=uncond_mid_block,
|
||||
down_intrablock_additional_residuals=uncond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
}
|
||||
|
||||
# do requested cross attention types for conditioning (positive prompt)
|
||||
cross_attn_processor_context.cross_attention_types_to_do = cross_attention_control_types_to_do
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data.text_embeddings.embeds,
|
||||
{"swap_cross_attn_context": cross_attn_processor_context},
|
||||
down_block_additional_residuals=cond_down_block,
|
||||
mid_block_additional_residual=cond_mid_block,
|
||||
down_intrablock_additional_residuals=cond_down_intrablock,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
@ -572,115 +510,3 @@ class InvokeAIDiffuserComponent:
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
return combined_next_x
|
||||
|
||||
def apply_symmetry(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
percent_through: float,
|
||||
) -> torch.Tensor:
|
||||
# Reset our last percent through if this is our first step.
|
||||
if percent_through == 0.0:
|
||||
self.last_percent_through = 0.0
|
||||
|
||||
if postprocessing_settings is None:
|
||||
return latents
|
||||
|
||||
# Check for out of bounds
|
||||
h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
|
||||
if h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0):
|
||||
h_symmetry_time_pct = None
|
||||
|
||||
v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
|
||||
if v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0):
|
||||
v_symmetry_time_pct = None
|
||||
|
||||
dev = latents.device.type
|
||||
|
||||
latents.to(device="cpu")
|
||||
|
||||
if (
|
||||
h_symmetry_time_pct is not None
|
||||
and self.last_percent_through < h_symmetry_time_pct
|
||||
and percent_through >= h_symmetry_time_pct
|
||||
):
|
||||
# Horizontal symmetry occurs on the 3rd dimension of the latent
|
||||
width = latents.shape[3]
|
||||
x_flipped = torch.flip(latents, dims=[3])
|
||||
latents = torch.cat(
|
||||
[
|
||||
latents[:, :, :, 0 : int(width / 2)],
|
||||
x_flipped[:, :, :, int(width / 2) : int(width)],
|
||||
],
|
||||
dim=3,
|
||||
)
|
||||
|
||||
if (
|
||||
v_symmetry_time_pct is not None
|
||||
and self.last_percent_through < v_symmetry_time_pct
|
||||
and percent_through >= v_symmetry_time_pct
|
||||
):
|
||||
# Vertical symmetry occurs on the 2nd dimension of the latent
|
||||
height = latents.shape[2]
|
||||
y_flipped = torch.flip(latents, dims=[2])
|
||||
latents = torch.cat(
|
||||
[
|
||||
latents[:, :, 0 : int(height / 2)],
|
||||
y_flipped[:, :, int(height / 2) : int(height)],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
self.last_percent_through = percent_through
|
||||
return latents.to(device=dev)
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = (
|
||||
c_or_weighted_c_list if isinstance(c_or_weighted_c_list, list) else [(c_or_weighted_c_list, 1)]
|
||||
)
|
||||
|
||||
# below is fugly omg
|
||||
conditionings = [uc] + [c for c, weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c, weight in weighted_cond_list]
|
||||
chunk_count = math.ceil(len(conditionings) / 2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index * 2
|
||||
chunk_size = min(2, len(conditionings) - offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset : offset + 2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
@ -1,52 +1,55 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.stable_diffusion.diffusion.custom_attention import CustomAttnProcessor2_0
|
||||
|
||||
|
||||
class UNetPatcher:
|
||||
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
|
||||
class UNetAttentionPatcher:
|
||||
"""A class for patching a UNet with CustomAttnProcessor2_0 attention layers."""
|
||||
|
||||
def __init__(self, ip_adapters: list[IPAdapter]):
|
||||
def __init__(self, ip_adapters: Optional[list[IPAdapter]]):
|
||||
self._ip_adapters = ip_adapters
|
||||
self._scales = [1.0] * len(self._ip_adapters)
|
||||
self._ip_adapter_scales = None
|
||||
|
||||
if self._ip_adapters is not None:
|
||||
self._ip_adapter_scales = [1.0] * len(self._ip_adapters)
|
||||
|
||||
def set_scale(self, idx: int, value: float):
|
||||
self._scales[idx] = value
|
||||
self._ip_adapter_scales[idx] = value
|
||||
|
||||
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
|
||||
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
|
||||
weights into them.
|
||||
weights into them (if IP-Adapters are being applied).
|
||||
|
||||
Note that the `unet` param is only used to determine attention block dimensions and naming.
|
||||
"""
|
||||
# Construct a dict of attention processors based on the UNet's architecture.
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(unet.attn_processors.keys()):
|
||||
if name.endswith("attn1.processor"):
|
||||
attn_procs[name] = AttnProcessor2_0()
|
||||
if name.endswith("attn1.processor") or self._ip_adapters is None:
|
||||
# "attn1" processors do not use IP-Adapters.
|
||||
attn_procs[name] = CustomAttnProcessor2_0()
|
||||
else:
|
||||
# Collect the weights from each IP Adapter for the idx'th attention processor.
|
||||
attn_procs[name] = IPAttnProcessor2_0(
|
||||
attn_procs[name] = CustomAttnProcessor2_0(
|
||||
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
|
||||
self._scales,
|
||||
self._ip_adapter_scales,
|
||||
)
|
||||
return attn_procs
|
||||
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
|
||||
"""A context manager that patches `unet` with IP-Adapter attention processors."""
|
||||
|
||||
"""A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
|
||||
attn_procs = self._prepare_attention_processors(unet)
|
||||
|
||||
orig_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
|
||||
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
|
||||
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
|
||||
# the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
|
||||
# moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
|
||||
unet.set_attn_processor(attn_procs)
|
||||
yield None
|
||||
finally:
|
@ -1,8 +1,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
from invokeai.backend.util.test_utils import install_and_load_model
|
||||
|
||||
|
||||
@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
|
||||
ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
|
||||
|
||||
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
|
||||
ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
|
||||
ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
|
||||
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
|
||||
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
|
||||
|
||||
|
Reference in New Issue
Block a user