Move regional prompt concatenation further up the stack. This solves a number of issues.

Ryan Dick 2024-02-28 20:11:47 -05:00
parent 53ebca58ff
commit 5f49e7ae26
5 changed files with 244 additions and 266 deletions

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import math
from contextlib import ExitStack
@@ -9,6 +8,7 @@ from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
+import torchvision
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
@@ -44,8 +44,10 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    IPAdapterConditioningInfo,
+    Range,
    SDXLConditioningInfo,
    TextConditioningData,
+    TextConditioningRegions,
)
from ...backend.model_management.lora import ModelPatcher
@@ -334,7 +336,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
        context: InvocationContext,
        device: torch.device,
        dtype: torch.dtype,
-    ):
+    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+        """Get the text embeddings and masks from the input conditioning fields."""
        # Normalize cond_field to a list.
        cond_list = cond_field
        if not isinstance(cond_list, list):
@@ -353,12 +356,111 @@ class DenoiseLatentsInvocation(BaseInvocation):
        return text_embeddings, text_embeddings_masks

+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
+        """
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        mask = tf(mask)
+        return mask
+
+    def concat_regional_text_embeddings(
+        self,
+        text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
+        masks: Optional[list[Optional[torch.Tensor]]],
+        latent_height: int,
+        latent_width: int,
+    ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+        """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
+        if masks is None:
+            masks = [None] * len(text_conditionings)
+        assert len(text_conditionings) == len(masks)
+
+        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
+        all_masks_are_none = all(mask is None for mask in masks)
+
+        text_embedding = []
+        pooled_embedding = None
+        add_time_ids = None
+        cur_text_embedding_len = 0
+        processed_masks = []
+        embedding_ranges = []
+
+        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
+            # HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
+            # embedding is passed in and CAC is being used.
+            assert (
+                text_embedding_info.extra_conditioning is None
+                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
+            )
+
+            if is_sdxl:
+                # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
+                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
+                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+                # pretty major breaking change to a popular node, so for now we use this hack.
+                #
+                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
+                # most likely to be a global prompt.
+                if pooled_embedding is None:
+                    pooled_embedding = text_embedding_info.pooled_embeds
+                if add_time_ids is None:
+                    add_time_ids = text_embedding_info.add_time_ids
+
+            text_embedding.append(text_embedding_info.embeds)
+            if not all_masks_are_none:
+                embedding_ranges.append(
+                    Range(
+                        start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
+                    )
+                )
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
+
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+        text_embedding = torch.cat(text_embedding, dim=1)
+        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
+
+        regions = None
+        if not all_masks_are_none:
+            regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
+
+        if is_sdxl:
+            return SDXLConditioningInfo(
+                embeds=text_embedding,
+                # TODO(ryand): This should not be hard-coded to None.
+                extra_conditioning=None,
+                pooled_embeds=pooled_embedding,
+                add_time_ids=add_time_ids,
+            ), regions
+        return BasicConditioningInfo(
+            embeds=text_embedding,
+            # TODO(ryand): This should not be hard-coded to None.
+            extra_conditioning=None,
+        ), regions
+
    def get_conditioning_data(
        self,
        context: InvocationContext,
-        scheduler,
        unet,
-        seed,
+        latent_height: int,
+        latent_width: int,
    ) -> TextConditioningData:
        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
            self.positive_conditioning, context, unet.device, unet.dtype
@@ -366,12 +468,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
            self.negative_conditioning, context, unet.device, unet.dtype
        )

+        cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
+            text_conditionings=cond_text_embeddings,
+            masks=cond_text_embedding_masks,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
+        uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
+            text_conditionings=uncond_text_embeddings,
+            masks=uncond_text_embedding_masks,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )

        conditioning_data = TextConditioningData(
-            uncond_text_embeddings=uncond_text_embeddings,
-            uncond_text_embedding_masks=uncond_text_embedding_masks,
-            cond_text_embeddings=cond_text_embeddings,
-            cond_text_embedding_masks=cond_text_embedding_masks,
+            uncond_text=uncond_text_embedding,
+            cond_text=cond_text_embedding,
+            uncond_regions=uncond_regions,
+            cond_regions=cond_regions,
            guidance_scale=self.cfg_scale,
            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
        )
@@ -761,7 +874,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
            )

            pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
+
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )

            controlnet_data = self.prep_control_data(
                context=context,

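To make the bookkeeping in the new `concat_regional_text_embeddings` path easier to follow, here is a minimal, self-contained sketch of what it computes: each prompt's embedding is appended to one long sequence, a `Range` records which slice of that sequence the prompt owns, and the per-prompt masks are stacked along a prompt dimension. The tensor shapes (77/65 tokens, a 64x64 latent) are illustrative assumptions, and the simplified `Range` below mirrors the dataclass added in this commit.

from dataclasses import dataclass

import torch


@dataclass
class Range:
    start: int
    end: int


# Two regional prompts with different token lengths: (batch=1, seq_len, token_dim).
embeds = [torch.randn(1, 77, 768), torch.randn(1, 65, 768)]
# One bool mask per prompt at latent resolution: (1, 1, h, w).
masks = [torch.zeros(1, 1, 64, 64, dtype=torch.bool), torch.zeros(1, 1, 64, 64, dtype=torch.bool)]
masks[0][..., :32] = True  # first prompt covers the left half of the image
masks[1][..., 32:] = True  # second prompt covers the right half

ranges = []
cursor = 0
for e in embeds:
    # Track which slice of the concatenated embedding belongs to this prompt.
    ranges.append(Range(start=cursor, end=cursor + e.shape[1]))
    cursor += e.shape[1]

text_embedding = torch.cat(embeds, dim=1)  # (1, 142, 768)
region_masks = torch.cat(masks, dim=1)     # (1, 2, 64, 64)
print(text_embedding.shape, region_masks.shape, ranges)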
View File

@@ -411,13 +411,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if timesteps.shape[0] == 0:
            return latents

-        extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
+        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
        use_cross_attention_control = (
            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
        )
        use_ip_adapter = ip_adapter_data is not None
        # HACK(ryand): Fix this logic.
-        use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
+        use_regional_prompting = conditioning_data.cond_regions is not None
        if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
            raise Exception(
                "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."

View File

@@ -8,6 +8,11 @@ from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
+    """Extra conditioning information produced by Compel.
+
+    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
+    """
+
    tokens_count_including_eos_bos: int
    cross_attention_control_args: Optional[Arguments] = None
@@ -54,20 +59,48 @@ class IPAdapterConditioningInfo:
@dataclass
-class TextConditioningData:
-    uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
-    uncond_text_embedding_masks: list[Optional[torch.Tensor]]
-    cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
-    cond_text_embedding_masks: list[Optional[torch.Tensor]]
-    """
-    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    """
-    guidance_scale: Union[float, List[float]]
-    """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
-    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
-    """
-    guidance_rescale_multiplier: float = 0
+class Range:
+    start: int
+    end: int
+
+
+class TextConditioningRegions:
+    def __init__(self, masks: torch.Tensor, ranges: list[Range]):
+        # A binary mask indicating the regions of the image that the prompt should be applied to.
+        # Shape: (1, num_prompts, height, width)
+        # Dtype: torch.bool
+        self.masks = masks
+
+        # A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
+        # ranges[i] contains the embedding range for the i'th prompt / mask.
+        self.ranges = ranges
+
+        assert self.masks.shape[1] == len(self.ranges)
+
+
+class TextConditioningData:
+    def __init__(
+        self,
+        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        uncond_regions: Optional[TextConditioningRegions],
+        cond_regions: Optional[TextConditioningRegions],
+        guidance_scale: Union[float, List[float]],
+        guidance_rescale_multiplier: float = 0,
+    ):
+        self.uncond_text = uncond_text
+        self.cond_text = cond_text
+        self.uncond_regions = uncond_regions
+        self.cond_regions = cond_regions
+        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+        # Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
+        # generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        self.guidance_scale = guidance_scale
+        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use a guidance_rescale_multiplier of 0.7.
+        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+        self.guidance_rescale_multiplier = guidance_rescale_multiplier
+
+    def is_sdxl(self):
+        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
+        return isinstance(self.cond_text, SDXLConditioningInfo)

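The two guidance fields documented above are consumed later in the denoising loop. As a rough sketch (the function name and exact call site here are assumptions, not part of this diff), classifier-free guidance plus the ztsnr rescale described in the comments combine the two noise predictions roughly like this:

import torch


def apply_cfg(uncond_noise: torch.Tensor, cond_noise: torch.Tensor,
              guidance_scale: float, guidance_rescale_multiplier: float = 0.0) -> torch.Tensor:
    # Classifier-free guidance: `w` in eq. 2 of the Imagen paper.
    guided = uncond_noise + guidance_scale * (cond_noise - uncond_noise)
    if guidance_rescale_multiplier > 0:
        # Rescale toward the std of the conditioned prediction, per "Common Diffusion Noise
        # Schedules and Sample Steps are Flawed", then blend by the rescale multiplier.
        dims = tuple(range(1, cond_noise.ndim))
        std_cond = cond_noise.std(dim=dims, keepdim=True)
        std_guided = guided.std(dim=dims, keepdim=True)
        rescaled = guided * (std_cond / std_guided)
        guided = guidance_rescale_multiplier * rescaled + (1 - guidance_rescale_multiplier) * guided
    return guided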
View File

@@ -1,5 +1,4 @@
from contextlib import contextmanager
-from dataclasses import dataclass
from typing import Optional

import torch
@@ -8,11 +7,9 @@ from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND

-
-@dataclass
-class Range:
-    start: int
-    end: int
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    TextConditioningRegions,
+)


class RegionalPromptData:
@@ -20,10 +17,9 @@ class RegionalPromptData:
        self._attn_masks_by_seq_len = attn_masks_by_seq_len

    @classmethod
-    def from_masks_and_ranges(
+    def from_regions(
        cls,
-        masks: list[torch.Tensor],
-        embedding_ranges: list[list[Range]],
+        regions: list[TextConditioningRegions],
        key_seq_len: int,
        # TODO(ryand): Pass in a list of downscale factors?
        max_downscale_factor: int = 8,
@@ -31,14 +27,8 @@ class RegionalPromptData:
        """Construct a `RegionalPromptData` object.

        Args:
-            masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
-                The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
-                prompt should be applied, and 0.0 elsewhere.
-            embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
-                prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
-                encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
+            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
+                batch.
            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
                explicitly to be sure.
@@ -48,11 +38,11 @@ class RegionalPromptData:
        # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
        # length of s.
        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []

-        for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
+        for batch_sample_regions in regions:
            batch_attn_masks_by_seq_len.append({})

            # Convert the bool masks to float masks so that max pooling can be applied.
-            batch_masks = batch_masks.to(dtype=torch.float32)
+            batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)

            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
            downscale_factor = 1
@@ -69,7 +59,7 @@ class RegionalPromptData:
                # TODO(ryand): What device / dtype should this be?
                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))

-                for prompt_idx, embedding_range in enumerate(batch_ranges):
+                for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
                        :, prompt_idx, :, :
                    ]

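For intuition about the mask layout that `from_regions` builds: each query (latent) position is allowed to attend only to the key tokens whose range belongs to the prompt covering that position. A small hand-rolled sketch with assumed sizes (two prompts, a 4x4 latent, 10 key tokens; not the library's code):

import torch

# Assumed sizes: 2 prompts, a 4x4 latent (query_seq_len = 16), 10 key tokens in total.
h = w = 4
query_seq_len, key_seq_len = h * w, 10

region_masks = torch.zeros(1, 2, h, w, dtype=torch.bool)
region_masks[0, 0, :, : w // 2] = True  # prompt 0 -> left half of the latent
region_masks[0, 1, :, w // 2 :] = True  # prompt 1 -> right half of the latent
ranges = [(0, 6), (6, 10)]              # token slices of each prompt in the concatenated embedding

attn_mask = torch.zeros(1, query_seq_len, key_seq_len)
query_masks = region_masks.float().view(1, 2, query_seq_len, 1)
for prompt_idx, (start, end) in enumerate(ranges):
    # Broadcast each prompt's spatial mask across that prompt's key-token slice.
    attn_mask[0, :, start:end] = query_masks[0, prompt_idx]

# attn_mask[0, q, k] is 1.0 exactly where latent position q may attend to key token k.
print(attn_mask.shape, attn_mask.sum().item())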
View File

@@ -5,19 +5,18 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union

import torch
-import torchvision
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    BasicConditioningInfo,
    ExtraConditioningInfo,
    IPAdapterConditioningInfo,
-    SDXLConditioningInfo,
+    Range,
    TextConditioningData,
+    TextConditioningRegions,
)
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData

from .cross_attention_control import (
    CrossAttentionType,
@@ -36,143 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
]

-class RegionalTextConditioningInfo:
-    def __init__(
-        self,
-        text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
-        masks: Optional[torch.Tensor] = None,
-        embedding_ranges: Optional[list[Range]] = None,
-    ):
-        """Initialize a RegionalTextConditioningInfo.
-
-        Args:
-            text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
-                after concatenating the embeddings for all regions.
-            masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
-            embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
-        """
-        self.text_conditioning = text_conditioning
-        self.masks = masks
-        self.embedding_ranges = embedding_ranges
-
-        assert (self.masks is None) == (self.embedding_ranges is None)
-        if self.masks is not None:
-            assert self.masks.shape[1] == len(self.embedding_ranges)
-
-    def has_region_masks(self):
-        if self.masks is None:
-            return False
-        return any(mask is not None for mask in self.masks)
-
-    def is_sdxl(self):
-        return isinstance(self.text_conditioning, SDXLConditioningInfo)
-
-    @classmethod
-    def _preprocess_regional_prompt_mask(
-        cls, mask: Optional[torch.Tensor], target_height: int, target_width: int
-    ) -> torch.Tensor:
-        """Preprocess a regional prompt mask to match the target height and width.
-
-        If mask is None, returns a mask of all ones with the target height and width.
-        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
-
-        Returns:
-            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
-        """
-        if mask is None:
-            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
-
-        tf = torchvision.transforms.Resize(
-            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-        )
-        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
-        mask = tf(mask)
-        return mask
-
-    @classmethod
-    def from_text_conditioning_and_masks(
-        cls,
-        text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
-        masks: Optional[list[Optional[torch.Tensor]]],
-        latent_height: int,
-        latent_width: int,
-    ):
-        if masks is None:
-            masks = [None] * len(text_conditionings)
-        assert len(text_conditionings) == len(masks)
-
-        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
-        all_masks_are_none = all(mask is None for mask in masks)
-
-        text_embedding = []
-        pooled_embedding = None
-        add_time_ids = None
-        processed_masks = []
-        cur_text_embedding_len = 0
-        embedding_ranges: list[Range] = []
-
-        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
-
-            if is_sdxl:
-                # HACK(ryand): We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
-                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
-                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
-                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
-                # pretty major breaking change to a popular node, so for now we use this hack.
-                #
-                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
-                # most likely to be a global prompt.
-                if pooled_embedding is None:
-                    pooled_embedding = text_embedding_info.pooled_embeds
-                if add_time_ids is None:
-                    add_time_ids = text_embedding_info.add_time_ids
-
-            text_embedding.append(text_embedding_info.embeds)
-            embedding_ranges.append(
-                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
-            )
-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-            if not all_masks_are_none:
-                processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
-
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        if not all_masks_are_none:
-            processed_masks = torch.cat(processed_masks, dim=1)
-        else:
-            processed_masks = None
-            embedding_ranges = None
-
-        if is_sdxl:
-            return cls(
-                text_conditioning=SDXLConditioningInfo(
-                    embeds=text_embedding,
-                    extra_conditioning=None,
-                    pooled_embeds=pooled_embedding,
-                    add_time_ids=add_time_ids,
-                ),
-                masks=processed_masks,
-                embedding_ranges=embedding_ranges,
-            )
-        return cls(
-            text_conditioning=BasicConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=None,
-            ),
-            masks=processed_masks,
-            embedding_ranges=embedding_ranges,
-        )
-
-
class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
@@ -233,10 +95,6 @@ class InvokeAIDiffuserComponent:
        conditioning_data: TextConditioningData,
    ):
        down_block_res_samples, mid_block_res_sample = None, None

-        # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
-        # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
-        uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
-        cond_text_embeddings = conditioning_data.cond_text_embeddings[0]

        # control_data should be type List[ControlNetData]
        # this loop covers both ControlNet (one ControlNetData in list)
@@ -267,25 +125,30 @@ class InvokeAIDiffuserComponent:
                added_cond_kwargs = None
                if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                    if type(cond_text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                        added_cond_kwargs = {
-                            "text_embeds": cond_text_embeddings.pooled_embeds,
-                            "time_ids": cond_text_embeddings.add_time_ids,
+                            "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                            "time_ids": conditioning_data.cond_text.add_time_ids,
                        }
-                    encoder_hidden_states = cond_text_embeddings.embeds
+                    encoder_hidden_states = conditioning_data.cond_text.embeds
                    encoder_attention_mask = None
                else:
-                    if type(cond_text_embeddings) is SDXLConditioningInfo:
+                    if conditioning_data.is_sdxl():
                        added_cond_kwargs = {
                            "text_embeds": torch.cat(
-                                [uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
+                                [
+                                    conditioning_data.uncond_text.pooled_embeds,
+                                    conditioning_data.cond_text.pooled_embeds,
+                                ],
+                                dim=0,
                            ),
                            "time_ids": torch.cat(
-                                [uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
+                                [conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
+                                dim=0,
                            ),
                        }
                    (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
-                        uncond_text_embeddings.embeds, cond_text_embeddings.embeds
+                        conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
                    )
                if isinstance(control_datum.weight, list):
                    # if controlnet has multiple weights, use the weight for the current step
@@ -440,52 +303,7 @@ class InvokeAIDiffuserComponent:
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

-        # TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
-        # denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
-        # to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
-        # the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
-        # is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
-        # stack.
        cross_attention_kwargs = None
-        _, _, h, w = x.shape
-        cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.cond_text_embeddings,
-            masks=conditioning_data.cond_text_embedding_masks,
-            latent_height=h,
-            latent_width=w,
-        )
-        uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.uncond_text_embeddings,
-            masks=conditioning_data.uncond_text_embedding_masks,
-            latent_height=h,
-            latent_width=w,
-        )
-
-        if cond_text.has_region_masks() or uncond_text.has_region_masks():
-            masks = []
-            embedding_ranges = []
-            for c in [uncond_text, cond_text]:
-                if c.has_region_masks():
-                    masks.append(c.masks)
-                    embedding_ranges.append(c.embedding_ranges)
-                else:
-                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
-                    masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
-                    embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])
-
-            # The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
-            # embeddings will be padded to match this length.
-            key_seq_len = 0
-            for c in [uncond_text, cond_text]:
-                _, seq_len, _ = c.text_conditioning.embeds.shape
-                if seq_len > key_seq_len:
-                    key_seq_len = seq_len
-
-            cross_attention_kwargs = {
-                "regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
-                    masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
-                )
-            }

        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
        if ip_adapter_conditioning is not None:
@@ -499,20 +317,45 @@ class InvokeAIDiffuserComponent:
                ]
            }

+        uncond_text = conditioning_data.uncond_text
+        cond_text = conditioning_data.cond_text
+
        added_cond_kwargs = None
-        if cond_text.is_sdxl():
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
-                ),
-                "time_ids": torch.cat(
-                    [uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
-                ),
+                "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
+                "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
+            uncond_text.embeds, cond_text.embeds
        )

+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
+            # masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            _, key_seq_len, _ = both_conditionings.shape
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_regions(regions=regions, key_seq_len=key_seq_len)
+            }
+
        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,
@@ -542,9 +385,6 @@ class InvokeAIDiffuserComponent:
            slower execution speed.
        """
-        assert len(conditioning_data.cond_text_embeddings) == 1
-        text_embeddings = conditioning_data.cond_text_embeddings[0]
-
        # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
        # and T2I-Adapter residuals into two chunks.
        uncond_down_block, cond_down_block = None, None
@@ -602,18 +442,17 @@ class InvokeAIDiffuserComponent:
        # Prepare SDXL conditioning kwargs for the unconditioned pass.
        added_cond_kwargs = None
-        is_sdxl = type(text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": conditioning_data.uncond_text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.uncond_text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text.add_time_ids,
            }

        # Run unconditioned UNet denoising (i.e. negative prompt).
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
-            conditioning_data.uncond_text_embeddings.embeds,
+            conditioning_data.uncond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
@@ -644,17 +483,17 @@ class InvokeAIDiffuserComponent:
        # Prepare SDXL conditioning kwargs for the conditioned pass.
        added_cond_kwargs = None
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
-                "text_embeds": text_embeddings.pooled_embeds,
-                "time_ids": text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                "time_ids": conditioning_data.cond_text.add_time_ids,
            }

        # Run conditioned UNet denoising (i.e. positive prompt).
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
-            text_embeddings.embeds,
+            conditioning_data.cond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
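
To close, a compressed sketch of what this sequential path computes (function and argument names here are placeholders, not the project's API): the unconditioned and conditioned predictions come from two half-size forward passes and are then combined with the guidance scale, which is the memory-for-speed trade-off the "slower execution speed" note above refers to.

from typing import Callable

import torch


def denoise_step_sequential(
    model_forward: Callable[..., torch.Tensor],
    x: torch.Tensor,
    sigma: torch.Tensor,
    uncond_embeds: torch.Tensor,
    cond_embeds: torch.Tensor,
    guidance_scale: float,
) -> torch.Tensor:
    # Two half-batch passes instead of one doubled batch: lower peak memory, higher latency.
    uncond_out = model_forward(x, sigma, uncond_embeds)
    cond_out = model_forward(x, sigma, cond_embeds)
    return uncond_out + guidance_scale * (cond_out - uncond_out)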