Move regional prompt concatenation further up the stack. This solves a number of issues.

Ryan Dick 2024-02-28 20:11:47 -05:00
parent 53ebca58ff
commit 5f49e7ae26
5 changed files with 244 additions and 266 deletions

View File

@@ -1,5 +1,4 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import math
from contextlib import ExitStack
@@ -9,6 +8,7 @@ from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
@@ -44,8 +44,10 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
IPAdapterConditioningInfo,
Range,
SDXLConditioningInfo,
TextConditioningData,
TextConditioningRegions,
)
from ...backend.model_management.lora import ModelPatcher
@@ -334,7 +336,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
device: torch.device,
dtype: torch.dtype,
):
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
"""Get the text embeddings and masks from the input conditioning fields."""
# Normalize cond_field to a list.
cond_list = cond_field
if not isinstance(cond_list, list):
@@ -353,12 +356,111 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
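As a quick sanity check of this helper's shape contract, here is a minimal sketch with made-up sizes (not part of the commit; the float round-trip is an assumption of the sketch, since nearest-neighbor interpolation is not implemented for bool tensors on all backends):

import torch
import torchvision.transforms as T

# Hypothetical 8x8 mask for a prompt that covers the top half of the image.
mask = torch.zeros((1, 8, 8), dtype=torch.bool)
mask[:, :4, :] = True

# Mirror the helper: (1, h, w) -> (1, 1, h, w), then nearest-neighbor resize to the latent size.
tf = T.Resize((4, 4), interpolation=T.InterpolationMode.NEAREST)
resized = tf(mask.unsqueeze(0).to(torch.float32)).to(torch.bool)
assert resized.shape == (1, 1, 4, 4)
assert resized[0, 0, :2, :].all() and not resized[0, 0, 2:, :].any()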
def concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
cur_text_embedding_len = 0
processed_masks = []
embedding_ranges = []
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
# HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
# embedding is passed in and CAC is being used.
assert (
text_embedding_info.extra_conditioning is None
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
)
if is_sdxl:
# HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
# fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
# we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
#
# An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
# most likely to be a global prompt.
if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none:
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
if is_sdxl:
return SDXLConditioningInfo(
embeds=text_embedding,
# TODO(ryand): This should not be hard-coded to None.
extra_conditioning=None,
pooled_embeds=pooled_embedding,
add_time_ids=add_time_ids,
), regions
return BasicConditioningInfo(
embeds=text_embedding,
# TODO(ryand): This should not be hard-coded to None.
extra_conditioning=None,
), regions
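To make the Range bookkeeping concrete, a small sketch with hypothetical SD1.5 shapes (77-token prompts, 768-dim embeddings; the numbers are illustrative only):

import torch

# Two regional prompts, each encoded to a (1, 77, 768) embedding.
e1, e2 = torch.randn(1, 77, 768), torch.randn(1, 77, 768)

# The loop above concatenates along the sequence dimension...
combined = torch.cat([e1, e2], dim=1)  # (1, 154, 768)

# ...and records each prompt's slice via embedding_ranges:
# Range(start=0, end=77) and Range(start=77, end=154).
assert combined[:, 77:154, :].shape == (1, 77, 768)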
def get_conditioning_data(
self,
context: InvocationContext,
scheduler,
unet,
seed,
latent_height: int,
latent_width: int,
) -> TextConditioningData:
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
self.positive_conditioning, context, unet.device, unet.dtype
@@ -366,12 +468,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
self.negative_conditioning, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
)
uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
)
conditioning_data = TextConditioningData(
uncond_text_embeddings=uncond_text_embeddings,
uncond_text_embedding_masks=uncond_text_embedding_masks,
cond_text_embeddings=cond_text_embeddings,
cond_text_embedding_masks=cond_text_embedding_masks,
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
)
@@ -761,7 +874,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
)
controlnet_data = self.prep_control_data(
context=context,

View File

@@ -411,13 +411,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents
extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
use_cross_attention_control = (
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
)
use_ip_adapter = ip_adapter_data is not None
# HACK(ryand): Fix this logic.
use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
use_regional_prompting = conditioning_data.cond_regions is not None
if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
raise Exception(
"Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."

View File

@@ -8,6 +8,11 @@ from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
"""Extra conditioning information produced by Compel.
This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
"""
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@@ -54,20 +59,48 @@ class IPAdapterConditioningInfo:
@dataclass
class TextConditioningData:
uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
uncond_text_embedding_masks: list[Optional[torch.Tensor]]
cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
cond_text_embedding_masks: list[Optional[torch.Tensor]]
class Range:
start: int
end: int
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
guidance_scale: Union[float, List[float]]
""" for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
"""
guidance_rescale_multiplier: float = 0
class TextConditioningRegions:
def __init__(self, masks: torch.Tensor, ranges: list[Range]):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
# Dtype: torch.bool
self.masks = masks
# A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
# ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges
assert self.masks.shape[1] == len(self.ranges)
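For illustration, a hypothetical two-prompt region set at 64x64 latent resolution (the shapes follow the comments above; the values are made up, and Range is assumed to be the dataclass defined earlier in this file):

import torch

masks = torch.zeros((1, 2, 64, 64), dtype=torch.bool)
masks[:, 0, :, :32] = True  # prompt 0 applies to the left half
masks[:, 1, :, 32:] = True  # prompt 1 applies to the right half

regions = TextConditioningRegions(
    masks=masks,
    ranges=[Range(start=0, end=77), Range(start=77, end=154)],
)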
class TextConditioningData:
def __init__(
self,
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
):
self.uncond_text = uncond_text
self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2 of the [Imagen paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
# generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)
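For reference, a simplified sketch of how these two values are typically applied downstream (the real combination lives in the diffusion loop, not in this class, and this sketch takes the std over all elements rather than per sample):

import torch

def apply_cfg(uncond_pred: torch.Tensor, cond_pred: torch.Tensor,
              guidance_scale: float, rescale_multiplier: float) -> torch.Tensor:
    # Classifier-free guidance: push the prediction along (cond - uncond) by w.
    cfg = uncond_pred + guidance_scale * (cond_pred - uncond_pred)
    if rescale_multiplier > 0:
        # Per "Common Diffusion Noise Schedules and Sample Steps are Flawed":
        # match the guided prediction's std to the conditioned prediction's std,
        # then blend the rescaled and raw results by the multiplier.
        rescaled = cfg * (cond_pred.std() / cfg.std())
        cfg = rescale_multiplier * rescaled + (1 - rescale_multiplier) * cfg
    return cfg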

View File

@@ -1,5 +1,4 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
import torch
@@ -8,11 +7,9 @@ from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND
@dataclass
class Range:
start: int
end: int
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:
@@ -20,10 +17,9 @@ class RegionalPromptData:
self._attn_masks_by_seq_len = attn_masks_by_seq_len
@classmethod
def from_masks_and_ranges(
def from_regions(
cls,
masks: list[torch.Tensor],
embedding_ranges: list[list[Range]],
regions: list[TextConditioningRegions],
key_seq_len: int,
# TODO(ryand): Pass in a list of downscale factors?
max_downscale_factor: int = 8,
@@ -31,14 +27,8 @@
"""Construct a `RegionalPromptData` object.
Args:
masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
prompt should be applied, and 0.0 elsewhere.
embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
batch.
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
explicitly to be sure.
@@ -48,11 +38,11 @@
# batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
# length of s.
batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
for batch_sample_regions in regions:
batch_attn_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_masks = batch_masks.to(dtype=torch.float32)
batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
@@ -69,7 +59,7 @@
# TODO(ryand): What device / dtype should this be?
attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, embedding_range in enumerate(batch_ranges):
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
:, prompt_idx, :, :
]
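For intuition about the keys of _attn_masks_by_seq_len, a sketch with a hypothetical 64x64 latent: each factor-of-2 max-pool of the masks corresponds to one UNet attention resolution, and the query sequence length at each level is just height times width after downscaling.

# Hypothetical latent size; the UNet attends at progressively smaller resolutions.
h, w = 64, 64
downscale_factor = 1
while downscale_factor <= 8:  # max_downscale_factor
    query_seq_len = (h // downscale_factor) * (w // downscale_factor)
    print(downscale_factor, query_seq_len)  # 1 -> 4096, 2 -> 1024, 4 -> 256, 8 -> 64
    downscale_factor *= 2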

View File

@@ -5,19 +5,18 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
import torchvision
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ExtraConditioningInfo,
IPAdapterConditioningInfo,
SDXLConditioningInfo,
Range,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData
from .cross_attention_control import (
CrossAttentionType,
@@ -36,143 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
]
class RegionalTextConditioningInfo:
def __init__(
self,
text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
masks: Optional[torch.Tensor] = None,
embedding_ranges: Optional[list[Range]] = None,
):
"""Initialize a RegionalTextConditioningInfo.
Args:
text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
after concatenating the embeddings for all regions.
masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
"""
self.text_conditioning = text_conditioning
self.masks = masks
self.embedding_ranges = embedding_ranges
assert (self.masks is None) == (self.embedding_ranges is None)
if self.masks is not None:
assert self.masks.shape[1] == len(self.embedding_ranges)
def has_region_masks(self):
if self.masks is None:
return False
return any(mask is not None for mask in self.masks)
def is_sdxl(self):
return isinstance(self.text_conditioning, SDXLConditioningInfo)
@classmethod
def _preprocess_regional_prompt_mask(
cls, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
@classmethod
def from_text_conditioning_and_masks(
cls,
text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
):
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
processed_masks = []
cur_text_embedding_len = 0
embedding_ranges: list[Range] = []
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
# HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
assert (
text_embedding_info.extra_conditioning is None
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
)
if is_sdxl:
# HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
# fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
# we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
# pretty major breaking change to a popular node, so for now we use this hack.
#
# An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
# most likely to be a global prompt.
if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
embedding_ranges.append(
Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
if not all_masks_are_none:
processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
if not all_masks_are_none:
processed_masks = torch.cat(processed_masks, dim=1)
else:
processed_masks = None
embedding_ranges = None
if is_sdxl:
return cls(
text_conditioning=SDXLConditioningInfo(
embeds=text_embedding,
extra_conditioning=None,
pooled_embeds=pooled_embedding,
add_time_ids=add_time_ids,
),
masks=processed_masks,
embedding_ranges=embedding_ranges,
)
return cls(
text_conditioning=BasicConditioningInfo(
embeds=text_embedding,
extra_conditioning=None,
),
masks=processed_masks,
embedding_ranges=embedding_ranges,
)
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@@ -233,10 +95,6 @@ class InvokeAIDiffuserComponent:
conditioning_data: TextConditioningData,
):
down_block_res_samples, mid_block_res_sample = None, None
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
# concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
cond_text_embeddings = conditioning_data.cond_text_embeddings[0]
# control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list)
@@ -267,25 +125,30 @@
added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(cond_text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": cond_text_embeddings.pooled_embeds,
"time_ids": cond_text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}
encoder_hidden_states = cond_text_embeddings.embeds
encoder_hidden_states = conditioning_data.cond_text.embeds
encoder_attention_mask = None
else:
if type(cond_text_embeddings) is SDXLConditioningInfo:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
[
conditioning_data.uncond_text.pooled_embeds,
conditioning_data.cond_text.pooled_embeds,
],
dim=0,
),
"time_ids": torch.cat(
[uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
[conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
dim=0,
),
}
(encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
uncond_text_embeddings.embeds, cond_text_embeddings.embeds
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
)
if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step
@@ -440,52 +303,7 @@
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
# TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
# denoising step. The text conditionings and masks do not change from step to step, so this really only needs
# to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
# the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
# is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
# stack.
cross_attention_kwargs = None
_, _, h, w = x.shape
cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
text_conditionings=conditioning_data.cond_text_embeddings,
masks=conditioning_data.cond_text_embedding_masks,
latent_height=h,
latent_width=w,
)
uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
text_conditionings=conditioning_data.uncond_text_embeddings,
masks=conditioning_data.uncond_text_embedding_masks,
latent_height=h,
latent_width=w,
)
if cond_text.has_region_masks() or uncond_text.has_region_masks():
masks = []
embedding_ranges = []
for c in [uncond_text, cond_text]:
if c.has_region_masks():
masks.append(c.masks)
embedding_ranges.append(c.embedding_ranges)
else:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])
# The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
# embeddings will be padded to match this length.
key_seq_len = 0
for c in [uncond_text, cond_text]:
_, seq_len, _ = c.text_conditioning.embeds.shape
if seq_len > key_seq_len:
key_seq_len = seq_len
cross_attention_kwargs = {
"regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
)
}
# TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
if ip_adapter_conditioning is not None:
@@ -499,20 +317,45 @@
]
}
uncond_text = conditioning_data.uncond_text
cond_text = conditioning_data.cond_text
added_cond_kwargs = None
if cond_text.is_sdxl():
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
),
"time_ids": torch.cat(
[uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
),
"text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
"time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
}
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
uncond_text.embeds, cond_text.embeds
)
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
# masks do not change from step to step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
regions = []
for c, r in [
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
(conditioning_data.cond_text, conditioning_data.cond_regions),
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=torch.bool),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)
_, key_seq_len, _ = both_conditionings.shape
cross_attention_kwargs = {
"regional_prompt_data": RegionalPromptData.from_regions(regions=regions, key_seq_len=key_seq_len)
}
both_results = self.model_forward_callback(
x_twice,
sigma_twice,
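A quick shape check of the fallback path above, with hypothetical SD1.5 sizes and assuming zero-padding in _concat_conditionings_for_batch (the actual padding strategy may differ):

import torch

uncond = torch.randn(1, 77, 768)     # single negative prompt
cond = torch.randn(1, 154, 768)      # two positive prompts, concatenated
pad = torch.zeros(1, 154 - 77, 768)  # pad the shorter embedding (assumed zeros)
both = torch.cat([torch.cat([uncond, pad], dim=1), cond], dim=0)
_, key_seq_len, _ = both.shape       # key_seq_len == 154
# The dummy uncond region then covers the whole latent with Range(start=0, end=77):
# only attention-mask columns 0..76 carry the uncond prompt.
assert key_seq_len == 154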
@@ -542,9 +385,6 @@
slower execution speed.
"""
assert len(conditioning_data.cond_text_embeddings) == 1
text_embeddings = conditioning_data.cond_text_embeddings[0]
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
# and T2I-Adapter residuals into two chunks.
uncond_down_block, cond_down_block = None, None
@@ -602,18 +442,17 @@
# Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None
is_sdxl = type(text_embeddings) is SDXLConditioningInfo
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": conditioning_data.uncond_text_embeddings.pooled_embeds,
"time_ids": conditioning_data.uncond_text_embeddings.add_time_ids,
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
"time_ids": conditioning_data.uncond_text.add_time_ids,
}
# Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback(
x,
sigma,
conditioning_data.uncond_text_embeddings.embeds,
conditioning_data.uncond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
@@ -644,17 +483,17 @@
# Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None
if is_sdxl:
if conditioning_data.is_sdxl():
added_cond_kwargs = {
"text_embeds": text_embeddings.pooled_embeds,
"time_ids": text_embeddings.add_time_ids,
"text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.cond_text.add_time_ids,
}
# Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback(
x,
sigma,
text_embeddings.embeds,
conditioning_data.cond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,