mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move regional prompt concatenation further up the stack. This solves a number of issues.
This commit is contained in:
parent 53ebca58ff
commit 5f49e7ae26
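In outline: before this commit, the per-region prompt embeddings and masks were passed as lists all the way into InvokeAIDiffuserComponent, which re-concatenated them on every denoising step; after it, DenoiseLatentsInvocation concatenates them once, up front, and hands a single embedding tensor plus region/range metadata (TextConditioningRegions) down the stack. A minimal standalone sketch of that bookkeeping (illustrative names and shapes, not InvokeAI code):

from dataclasses import dataclass

import torch


@dataclass
class Range:
    start: int
    end: int


def concat_embeddings(embeds: list[torch.Tensor]) -> tuple[torch.Tensor, list[Range]]:
    """Concatenate (1, seq_len_i, dim) embeddings along the sequence axis, recording each prompt's slice."""
    ranges = []
    cur_len = 0
    for e in embeds:
        ranges.append(Range(start=cur_len, end=cur_len + e.shape[1]))
        cur_len += e.shape[1]
    return torch.cat(embeds, dim=1), ranges


# Two regional prompts, each encoded to a (1, 77, 768) embedding.
embeds = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]
text_embedding, ranges = concat_embeddings(embeds)
assert text_embedding.shape == (1, 154, 768)
assert ranges == [Range(start=0, end=77), Range(start=77, end=154)]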
@@ -1,5 +1,4 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
-
 import inspect
 import math
 from contextlib import ExitStack
@@ -9,6 +8,7 @@ from typing import List, Literal, Optional, Union
 import einops
 import numpy as np
 import torch
+import torchvision
 import torchvision.transforms as T
 from diffusers import AutoencoderKL, AutoencoderTiny
 from diffusers.image_processor import VaeImageProcessor
@@ -44,8 +44,10 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     IPAdapterConditioningInfo,
+    Range,
     SDXLConditioningInfo,
     TextConditioningData,
+    TextConditioningRegions,
 )
 
 from ...backend.model_management.lora import ModelPatcher
@@ -334,7 +336,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         context: InvocationContext,
         device: torch.device,
         dtype: torch.dtype,
-    ):
+    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
+        """Get the text embeddings and masks from the input conditioning fields."""
         # Normalize cond_field to a list.
         cond_list = cond_field
         if not isinstance(cond_list, list):
@@ -353,12 +356,111 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return text_embeddings, text_embeddings_masks
 
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
+        """
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        mask = tf(mask)
+
+        return mask
+
+    def concat_regional_text_embeddings(
+        self,
+        text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
+        masks: Optional[list[Optional[torch.Tensor]]],
+        latent_height: int,
+        latent_width: int,
+    ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+        """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
+        if masks is None:
+            masks = [None] * len(text_conditionings)
+        assert len(text_conditionings) == len(masks)
+
+        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
+
+        all_masks_are_none = all(mask is None for mask in masks)
+
+        text_embedding = []
+        pooled_embedding = None
+        add_time_ids = None
+        cur_text_embedding_len = 0
+        processed_masks = []
+        embedding_ranges = []
+
+        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
+            # HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
+            # embedding is passed in and CAC is being used.
+            assert (
+                text_embedding_info.extra_conditioning is None
+                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
+            )
+
+            if is_sdxl:
+                # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
+                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
+                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+                # pretty major breaking change to a popular node, so for now we use this hack.
+                #
+                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
+                # most likely to be a global prompt.
+                if pooled_embedding is None:
+                    pooled_embedding = text_embedding_info.pooled_embeds
+                if add_time_ids is None:
+                    add_time_ids = text_embedding_info.add_time_ids
+
+            text_embedding.append(text_embedding_info.embeds)
+            if not all_masks_are_none:
+                embedding_ranges.append(
+                    Range(
+                        start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
+                    )
+                )
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
+
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+        text_embedding = torch.cat(text_embedding, dim=1)
+        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
+
+        regions = None
+        if not all_masks_are_none:
+            regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)
+
+        if is_sdxl:
+            return SDXLConditioningInfo(
+                embeds=text_embedding,
+                # TODO(ryand): This should not be hard-coded to None.
+                extra_conditioning=None,
+                pooled_embeds=pooled_embedding,
+                add_time_ids=add_time_ids,
+            ), regions
+        return BasicConditioningInfo(
+            embeds=text_embedding,
+            # TODO(ryand): This should not be hard-coded to None.
+            extra_conditioning=None,
+        ), regions
+
     def get_conditioning_data(
         self,
         context: InvocationContext,
-        scheduler,
         unet,
-        seed,
+        latent_height: int,
+        latent_width: int,
     ) -> TextConditioningData:
         cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
             self.positive_conditioning, context, unet.device, unet.dtype
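For reference, a self-contained sketch of the nearest-neighbor resize that _preprocess_regional_prompt_mask performs above (sizes are assumed; the sketch keeps the mask as float during the resize and thresholds back to bool afterwards, since torch's nearest-neighbor interpolation may not be implemented for bool tensors):

import torch
import torchvision

# A (1, h, w) region mask at image resolution: the prompt applies to the left half.
mask = torch.zeros((1, 64, 64), dtype=torch.float32)
mask[:, :, :32] = 1.0

# Nearest-neighbor keeps region boundaries hard (no blended edge values).
tf = torchvision.transforms.Resize(
    (96, 96), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
resized = tf(mask.unsqueeze(0)) > 0.5  # (1, 1, 96, 96), torch.bool
assert resized.shape == (1, 1, 96, 96) and resized.dtype == torch.bool
assert resized[0, 0, 0, 0].item() and not resized[0, 0, 0, -1].item()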
@@ -366,12 +468,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
         uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
             self.negative_conditioning, context, unet.device, unet.dtype
         )
+        cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
+            text_conditionings=cond_text_embeddings,
+            masks=cond_text_embedding_masks,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
+        uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
+            text_conditionings=uncond_text_embeddings,
+            masks=uncond_text_embedding_masks,
+            latent_height=latent_height,
+            latent_width=latent_width,
+        )
         conditioning_data = TextConditioningData(
-            uncond_text_embeddings=uncond_text_embeddings,
-            uncond_text_embedding_masks=uncond_text_embedding_masks,
-            cond_text_embeddings=cond_text_embeddings,
-            cond_text_embedding_masks=cond_text_embedding_masks,
+            uncond_text=uncond_text_embedding,
+            cond_text=cond_text_embedding,
+            uncond_regions=uncond_regions,
+            cond_regions=cond_regions,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
         )
@@ -761,7 +874,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             )
 
             pipeline = self.create_pipeline(unet, scheduler)
-            conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
+            _, _, latent_height, latent_width = latents.shape
+            conditioning_data = self.get_conditioning_data(
+                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            )
 
             controlnet_data = self.prep_control_data(
                 context=context,
@@ -411,13 +411,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents
 
-        extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
+        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
         use_cross_attention_control = (
             extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
         )
         use_ip_adapter = ip_adapter_data is not None
         # HACK(ryand): Fix this logic.
-        use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
+        use_regional_prompting = conditioning_data.cond_regions is not None
         if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
             raise Exception(
                 "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
@@ -8,6 +8,11 @@ from .cross_attention_control import Arguments
 
 @dataclass
 class ExtraConditioningInfo:
+    """Extra conditioning information produced by Compel.
+
+    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
+    """
+
     tokens_count_including_eos_bos: int
     cross_attention_control_args: Optional[Arguments] = None
 
@@ -54,20 +59,48 @@ class IPAdapterConditioningInfo:
 
 
 @dataclass
-class TextConditioningData:
-    uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
-    uncond_text_embedding_masks: list[Optional[torch.Tensor]]
-    cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
-    cond_text_embedding_masks: list[Optional[torch.Tensor]]
-    """
-    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    """
-    guidance_scale: Union[float, List[float]]
-    """ for models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7 .
-    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
-    """
-    guidance_rescale_multiplier: float = 0
+class Range:
+    start: int
+    end: int
+
+
+class TextConditioningRegions:
+    def __init__(self, masks: torch.Tensor, ranges: list[Range]):
+        # A binary mask indicating the regions of the image that the prompt should be applied to.
+        # Shape: (1, num_prompts, height, width)
+        # Dtype: torch.bool
+        self.masks = masks
+
+        # A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
+        # ranges[i] contains the embedding range for the i'th prompt / mask.
+        self.ranges = ranges
+
+        assert self.masks.shape[1] == len(self.ranges)
+
+
+class TextConditioningData:
+    def __init__(
+        self,
+        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        uncond_regions: Optional[TextConditioningRegions],
+        cond_regions: Optional[TextConditioningRegions],
+        guidance_scale: Union[float, List[float]],
+        guidance_rescale_multiplier: float = 0,
+    ):
+        self.uncond_text = uncond_text
+        self.cond_text = cond_text
+        self.uncond_regions = uncond_regions
+        self.cond_regions = cond_regions
+        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+        self.guidance_scale = guidance_scale
+        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
+        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+        self.guidance_rescale_multiplier = guidance_rescale_multiplier
+
+    def is_sdxl(self):
+        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
+        return isinstance(self.cond_text, SDXLConditioningInfo)
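A toy stand-in (hypothetical class, not the code above) making the TextConditioningRegions invariant concrete: the i'th mask channel pairs with the i'th embedding range, so masks.shape[1] must equal len(ranges):

import torch


class Regions:
    """Toy stand-in for TextConditioningRegions (illustration only)."""

    def __init__(self, masks: torch.Tensor, ranges: list[tuple[int, int]]):
        # masks: (1, num_prompts, h, w) bool; ranges: one (start, end) slice per prompt.
        assert masks.shape[1] == len(ranges)
        self.masks = masks
        self.ranges = ranges


# Two prompts, two mask channels, two embedding ranges: the invariant holds.
regions = Regions(
    masks=torch.ones((1, 2, 96, 96), dtype=torch.bool),
    ranges=[(0, 77), (77, 154)],
)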
@@ -1,5 +1,4 @@
 from contextlib import contextmanager
-from dataclasses import dataclass
 from typing import Optional
 
 import torch
@@ -8,11 +7,9 @@ from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import USE_PEFT_BACKEND
 
-
-@dataclass
-class Range:
-    start: int
-    end: int
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    TextConditioningRegions,
+)
 
 
 class RegionalPromptData:
@@ -20,10 +17,9 @@ class RegionalPromptData:
         self._attn_masks_by_seq_len = attn_masks_by_seq_len
 
     @classmethod
-    def from_masks_and_ranges(
+    def from_regions(
         cls,
-        masks: list[torch.Tensor],
-        embedding_ranges: list[list[Range]],
+        regions: list[TextConditioningRegions],
         key_seq_len: int,
         # TODO(ryand): Pass in a list of downscale factors?
         max_downscale_factor: int = 8,
@@ -31,14 +27,8 @@ class RegionalPromptData:
         """Construct a `RegionalPromptData` object.
 
         Args:
-            masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
-                The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
-                prompt should be applied, and 0.0 elsewhere.
-
-            embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
-                prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
-                encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
-
+            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
+                batch.
             key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
                 cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
                 explicitly to be sure.
@@ -48,11 +38,11 @@ class RegionalPromptData:
         # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
         # length of s.
         batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
-        for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
+        for batch_sample_regions in regions:
             batch_attn_masks_by_seq_len.append({})
 
             # Convert the bool masks to float masks so that max pooling can be applied.
-            batch_masks = batch_masks.to(dtype=torch.float32)
+            batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
 
             # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
             downscale_factor = 1
@@ -69,7 +59,7 @@ class RegionalPromptData:
                 # TODO(ryand): What device / dtype should this be?
                 attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
 
-                for prompt_idx, embedding_range in enumerate(batch_ranges):
+                for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
                     attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
                         :, prompt_idx, :, :
                     ]
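A standalone toy (assumed sizes, simplified from from_regions above: no downsampling, batching, or per-seq-len caching) of how a regional attention mask is assembled: each latent position (query) is allowed to attend only to the embedding range (key columns) of the prompt whose mask covers it:

import torch

h = w = 8  # a downsampled latent grid
query_seq_len = h * w  # one query per latent position
key_seq_len = 154  # length of the concatenated prompt embedding
ranges = [(0, 77), (77, 154)]  # embedding slice owned by each prompt

masks = torch.zeros((2, h, w))
masks[0, :, : w // 2] = 1.0  # prompt 0 owns the left half of the image
masks[1, :, w // 2 :] = 1.0  # prompt 1 owns the right half

# Fill each prompt's key columns from its (flattened) spatial mask.
attn_mask = torch.zeros((query_seq_len, key_seq_len))
query_masks = masks.view(2, query_seq_len)
for prompt_idx, (start, end) in enumerate(ranges):
    attn_mask[:, start:end] = query_masks[prompt_idx].unsqueeze(-1)

# Query 0 sits in the left half: it may attend to prompt 0's tokens only.
assert attn_mask[0, 0] == 1.0 and attn_mask[0, 100] == 0.0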
@@ -5,19 +5,18 @@ from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
 import torch
-import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-    BasicConditioningInfo,
     ExtraConditioningInfo,
     IPAdapterConditioningInfo,
-    SDXLConditioningInfo,
+    Range,
     TextConditioningData,
+    TextConditioningRegions,
 )
-from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData
 
 from .cross_attention_control import (
     CrossAttentionType,
@@ -36,143 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
 ]
 
 
-class RegionalTextConditioningInfo:
-    def __init__(
-        self,
-        text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
-        masks: Optional[torch.Tensor] = None,
-        embedding_ranges: Optional[list[Range]] = None,
-    ):
-        """Initialize a RegionalTextConditioningInfo.
-
-        Args:
-            text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
-                after concatenating the embeddings for all regions.
-            masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
-            embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
-        """
-        self.text_conditioning = text_conditioning
-        self.masks = masks
-        self.embedding_ranges = embedding_ranges
-
-        assert (self.masks is None) == (self.embedding_ranges is None)
-        if self.masks is not None:
-            assert self.masks.shape[1] == len(self.embedding_ranges)
-
-    def has_region_masks(self):
-        if self.masks is None:
-            return False
-        return any(mask is not None for mask in self.masks)
-
-    def is_sdxl(self):
-        return isinstance(self.text_conditioning, SDXLConditioningInfo)
-
-    @classmethod
-    def _preprocess_regional_prompt_mask(
-        cls, mask: Optional[torch.Tensor], target_height: int, target_width: int
-    ) -> torch.Tensor:
-        """Preprocess a regional prompt mask to match the target height and width.
-
-        If mask is None, returns a mask of all ones with the target height and width.
-        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
-
-        Returns:
-            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
-        """
-        if mask is None:
-            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
-
-        tf = torchvision.transforms.Resize(
-            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-        )
-        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
-        mask = tf(mask)
-
-        return mask
-
-    @classmethod
-    def from_text_conditioning_and_masks(
-        cls,
-        text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
-        masks: Optional[list[Optional[torch.Tensor]]],
-        latent_height: int,
-        latent_width: int,
-    ):
-        if masks is None:
-            masks = [None] * len(text_conditionings)
-        assert len(text_conditionings) == len(masks)
-
-        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
-
-        all_masks_are_none = all(mask is None for mask in masks)
-
-        text_embedding = []
-        pooled_embedding = None
-        add_time_ids = None
-        processed_masks = []
-        cur_text_embedding_len = 0
-        embedding_ranges: list[Range] = []
-
-        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
-            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
-            assert (
-                text_embedding_info.extra_conditioning is None
-                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
-            )
-
-            if is_sdxl:
-                # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
-                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
-                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
-                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
-                # pretty major breaking change to a popular node, so for now we use this hack.
-                #
-                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
-                # most likely to be a global prompt.
-                if pooled_embedding is None:
-                    pooled_embedding = text_embedding_info.pooled_embeds
-                if add_time_ids is None:
-                    add_time_ids = text_embedding_info.add_time_ids
-
-            text_embedding.append(text_embedding_info.embeds)
-            embedding_ranges.append(
-                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
-            )
-            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-            if not all_masks_are_none:
-                processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
-
-        text_embedding = torch.cat(text_embedding, dim=1)
-        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-        if not all_masks_are_none:
-            processed_masks = torch.cat(processed_masks, dim=1)
-        else:
-            processed_masks = None
-            embedding_ranges = None
-
-        if is_sdxl:
-            return cls(
-                text_conditioning=SDXLConditioningInfo(
-                    embeds=text_embedding,
-                    extra_conditioning=None,
-                    pooled_embeds=pooled_embedding,
-                    add_time_ids=add_time_ids,
-                ),
-                masks=processed_masks,
-                embedding_ranges=embedding_ranges,
-            )
-        return cls(
-            text_conditioning=BasicConditioningInfo(
-                embeds=text_embedding,
-                extra_conditioning=None,
-            ),
-            masks=processed_masks,
-            embedding_ranges=embedding_ranges,
-        )
-
-
 class InvokeAIDiffuserComponent:
     """
     The aim of this component is to provide a single place for code that can be applied identically to
@@ -233,10 +95,6 @@ class InvokeAIDiffuserComponent:
         conditioning_data: TextConditioningData,
     ):
         down_block_res_samples, mid_block_res_sample = None, None
-        # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
-        # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
-        uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
-        cond_text_embeddings = conditioning_data.cond_text_embeddings[0]
 
         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
@@ -267,25 +125,30 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs = None
 
             if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                if type(cond_text_embeddings) is SDXLConditioningInfo:
+                if conditioning_data.is_sdxl():
                     added_cond_kwargs = {
-                        "text_embeds": cond_text_embeddings.pooled_embeds,
-                        "time_ids": cond_text_embeddings.add_time_ids,
+                        "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                        "time_ids": conditioning_data.cond_text.add_time_ids,
                     }
-                encoder_hidden_states = cond_text_embeddings.embeds
+                encoder_hidden_states = conditioning_data.cond_text.embeds
                 encoder_attention_mask = None
             else:
-                if type(cond_text_embeddings) is SDXLConditioningInfo:
+                if conditioning_data.is_sdxl():
                     added_cond_kwargs = {
                         "text_embeds": torch.cat(
-                            [uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
+                            [
+                                conditioning_data.uncond_text.pooled_embeds,
+                                conditioning_data.cond_text.pooled_embeds,
+                            ],
+                            dim=0,
                         ),
                         "time_ids": torch.cat(
-                            [uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
+                            [conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
+                            dim=0,
                         ),
                     }
                 (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
-                    uncond_text_embeddings.embeds, cond_text_embeddings.embeds
+                    conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
                 )
             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
@@ -440,52 +303,7 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        # TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
-        # denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
-        # to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
-        # the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
-        # is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
-        # stack.
         cross_attention_kwargs = None
-        _, _, h, w = x.shape
-        cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.cond_text_embeddings,
-            masks=conditioning_data.cond_text_embedding_masks,
-            latent_height=h,
-            latent_width=w,
-        )
-        uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.uncond_text_embeddings,
-            masks=conditioning_data.uncond_text_embedding_masks,
-            latent_height=h,
-            latent_width=w,
-        )
-
-        if cond_text.has_region_masks() or uncond_text.has_region_masks():
-            masks = []
-            embedding_ranges = []
-            for c in [uncond_text, cond_text]:
-                if c.has_region_masks():
-                    masks.append(c.masks)
-                    embedding_ranges.append(c.embedding_ranges)
-                else:
-                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
-                    masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
-                    embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])
-
-            # The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
-            # embeddings will be padded to match this length.
-            key_seq_len = 0
-            for c in [uncond_text, cond_text]:
-                _, seq_len, _ = c.text_conditioning.embeds.shape
-                if seq_len > key_seq_len:
-                    key_seq_len = seq_len
-
-            cross_attention_kwargs = {
-                "regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
-                    masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
-                )
-            }
 
         # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
         if ip_adapter_conditioning is not None:
@@ -499,20 +317,45 @@ class InvokeAIDiffuserComponent:
                 ]
             }
 
+        uncond_text = conditioning_data.uncond_text
+        cond_text = conditioning_data.cond_text
+
         added_cond_kwargs = None
-        if cond_text.is_sdxl():
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
-                ),
-                "time_ids": torch.cat(
-                    [uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
-                ),
+                "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
+                "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
             }
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
+            uncond_text.embeds, cond_text.embeds
         )
+
+        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+            # TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
+            # masks are not changing from step-to-step, so this really only needs to be done once. While this seems
+            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
+            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
+            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
+            regions = []
+            for c, r in [
+                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
+                (conditioning_data.cond_text, conditioning_data.cond_regions),
+            ]:
+                if r is None:
+                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
+                    _, _, h, w = x.shape
+                    r = TextConditioningRegions(
+                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
+                        ranges=[Range(start=0, end=c.embeds.shape[1])],
+                    )
+                regions.append(r)
+
+            _, key_seq_len, _ = both_conditionings.shape
+            cross_attention_kwargs = {
+                "regional_prompt_data": RegionalPromptData.from_regions(regions=regions, key_seq_len=key_seq_len)
+            }
+
         both_results = self.model_forward_callback(
             x_twice,
             sigma_twice,
@@ -542,9 +385,6 @@ class InvokeAIDiffuserComponent:
        slower execution speed.
         """
 
-        assert len(conditioning_data.cond_text_embeddings) == 1
-        text_embeddings = conditioning_data.cond_text_embeddings[0]
-
         # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
         # and T2I-Adapter residuals into two chunks.
         uncond_down_block, cond_down_block = None, None
@@ -602,18 +442,17 @@ class InvokeAIDiffuserComponent:
 
         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
-        is_sdxl = type(text_embeddings) is SDXLConditioningInfo
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.uncond_text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.uncond_text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text.add_time_ids,
             }
 
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.uncond_text_embeddings.embeds,
+            conditioning_data.uncond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
@@ -644,17 +483,17 @@ class InvokeAIDiffuserComponent:
 
         # Prepare SDXL conditioning kwargs for the conditioned pass.
         added_cond_kwargs = None
-        if is_sdxl:
+        if conditioning_data.is_sdxl():
             added_cond_kwargs = {
-                "text_embeds": text_embeddings.pooled_embeds,
-                "time_ids": text_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.cond_text.pooled_embeds,
+                "time_ids": conditioning_data.cond_text.add_time_ids,
             }
 
         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            text_embeddings.embeds,
+            conditioning_data.cond_text.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,