Move regional prompt concatenation further up the stack. This solves a number of issues.
This commit is contained in:
parent
53ebca58ff
commit
5f49e7ae26
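For orientation, a minimal sketch of the data flow this commit moves up the stack (illustrative shapes and names, not code from the diff): per-prompt text embeddings are concatenated along the sequence dimension, and each region mask is paired with the range of the concatenated sequence contributed by its prompt.

import torch

# Two prompts, each encoded to (batch, seq_len, token_dim) embeddings.
embeds_a = torch.randn(1, 77, 768)
embeds_b = torch.randn(1, 77, 768)

# Concatenate along the sequence dimension: (1, 154, 768).
concatenated = torch.cat([embeds_a, embeds_b], dim=1)

# Each region mask is paired with the slice of the concatenated sequence
# that its prompt contributed (cf. Range / TextConditioningRegions below).
ranges = [(0, 77), (77, 154)]
masks = torch.ones((1, 2, 96, 96), dtype=torch.bool)  # (1, num_prompts, latent_h, latent_w)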
@@ -1,5 +1,4 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import inspect
import math
from contextlib import ExitStack
@@ -9,6 +8,7 @@ from typing import List, Literal, Optional, Union
import einops
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
@@ -44,8 +44,10 @@ from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    IPAdapterConditioningInfo,
    Range,
    SDXLConditioningInfo,
    TextConditioningData,
    TextConditioningRegions,
)

from ...backend.model_management.lora import ModelPatcher
@@ -334,7 +336,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
        context: InvocationContext,
        device: torch.device,
        dtype: torch.dtype,
    ):
    ) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
        """Get the text embeddings and masks from the input conditioning fields."""
        # Normalize cond_field to a list.
        cond_list = cond_field
        if not isinstance(cond_list, list):
@@ -353,12 +356,111 @@ class DenoiseLatentsInvocation(BaseInvocation):

        return text_embeddings, text_embeddings_masks

    def _preprocess_regional_prompt_mask(
        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target height and width.

        If mask is None, returns a mask of all ones with the target height and width.
        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.

        Returns:
            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
        """
        if mask is None:
            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)

        tf = torchvision.transforms.Resize(
            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
        mask = tf(mask)

        return mask

    def concat_regional_text_embeddings(
        self,
        text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
        masks: Optional[list[Optional[torch.Tensor]]],
        latent_height: int,
        latent_width: int,
    ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
        """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
        if masks is None:
            masks = [None] * len(text_conditionings)
        assert len(text_conditionings) == len(masks)

        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo

        all_masks_are_none = all(mask is None for mask in masks)

        text_embedding = []
        pooled_embedding = None
        add_time_ids = None
        cur_text_embedding_len = 0
        processed_masks = []
        embedding_ranges = []

        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
            # HACK(ryand): Figure out the intended relationship with CAC. Probably want to raise if more than one text
            # embedding is passed in and CAC is being used.
            assert (
                text_embedding_info.extra_conditioning is None
                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
            )

            if is_sdxl:
                # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
                # pretty major breaking change to a popular node, so for now we use this hack.
                #
                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
                # most likely to be a global prompt.
                if pooled_embedding is None:
                    pooled_embedding = text_embedding_info.pooled_embeds
                if add_time_ids is None:
                    add_time_ids = text_embedding_info.add_time_ids

            text_embedding.append(text_embedding_info.embeds)
            if not all_masks_are_none:
                embedding_ranges.append(
                    Range(
                        start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
                    )
                )
                processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))

            cur_text_embedding_len += text_embedding_info.embeds.shape[1]

        text_embedding = torch.cat(text_embedding, dim=1)
        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

        regions = None
        if not all_masks_are_none:
            regions = TextConditioningRegions(masks=torch.cat(processed_masks, dim=1), ranges=embedding_ranges)

        if is_sdxl:
            return SDXLConditioningInfo(
                embeds=text_embedding,
                # TODO(ryand): This should not be hard-coded to None.
                extra_conditioning=None,
                pooled_embeds=pooled_embedding,
                add_time_ids=add_time_ids,
            ), regions
        return BasicConditioningInfo(
            embeds=text_embedding,
            # TODO(ryand): This should not be hard-coded to None.
            extra_conditioning=None,
        ), regions

    def get_conditioning_data(
        self,
        context: InvocationContext,
        scheduler,
        unet,
        seed,
        latent_height: int,
        latent_width: int,
    ) -> TextConditioningData:
        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
            self.positive_conditioning, context, unet.device, unet.dtype
@@ -366,12 +468,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
            self.negative_conditioning, context, unet.device, unet.dtype
        )

        cond_text_embedding, cond_regions = self.concat_regional_text_embeddings(
            text_conditionings=cond_text_embeddings,
            masks=cond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
        )
        uncond_text_embedding, uncond_regions = self.concat_regional_text_embeddings(
            text_conditionings=uncond_text_embeddings,
            masks=uncond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
        )
        conditioning_data = TextConditioningData(
            uncond_text_embeddings=uncond_text_embeddings,
            uncond_text_embedding_masks=uncond_text_embedding_masks,
            cond_text_embeddings=cond_text_embeddings,
            cond_text_embedding_masks=cond_text_embedding_masks,
            uncond_text=uncond_text_embedding,
            cond_text=cond_text_embedding,
            uncond_regions=uncond_regions,
            cond_regions=cond_regions,
            guidance_scale=self.cfg_scale,
            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
        )
@@ -761,7 +874,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
            )

            pipeline = self.create_pipeline(unet, scheduler)
            conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
            _, _, latent_height, latent_width = latents.shape
            conditioning_data = self.get_conditioning_data(
                context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
            )

            controlnet_data = self.prep_control_data(
                context=context,
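A hedged illustration of the mask preprocessing in `_preprocess_regional_prompt_mask` above: nearest-neighbor resizing keeps the mask binary when scaling it down to the latent resolution. (Standalone sketch; the 512 -> 64 sizes are example values.)

import torch
import torchvision.transforms as T

mask = torch.zeros((1, 512, 512))
mask[:, :256, :] = 1.0  # prompt applies to the top half of the image

tf = T.Resize((64, 64), interpolation=T.InterpolationMode.NEAREST)
latent_mask = tf(mask.unsqueeze(0)).bool()  # (1, 1, 64, 64)
assert latent_mask[0, 0, 0, 0] and not latent_mask[0, 0, 63, 0]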
@@ -411,13 +411,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if timesteps.shape[0] == 0:
            return latents

        extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
        extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
        use_cross_attention_control = (
            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
        )
        use_ip_adapter = ip_adapter_data is not None
        # HACK(ryand): Fix this logic.
        use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
        use_regional_prompting = conditioning_data.cond_regions is not None
        if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
            raise Exception(
                "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
@@ -8,6 +8,11 @@ from .cross_attention_control import Arguments

@dataclass
class ExtraConditioningInfo:
    """Extra conditioning information produced by Compel.

    This is used for prompt-to-prompt cross-attention control (a.k.a. `.swap()` in Compel).
    """

    tokens_count_including_eos_bos: int
    cross_attention_control_args: Optional[Arguments] = None
@@ -54,20 +59,48 @@ class IPAdapterConditioningInfo:


@dataclass
class TextConditioningData:
    uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
    uncond_text_embedding_masks: list[Optional[torch.Tensor]]
    cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
    cond_text_embedding_masks: list[Optional[torch.Tensor]]
class Range:
    start: int
    end: int

    """
    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
    Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
    generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
    """
    guidance_scale: Union[float, List[float]]
    """For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use a guidance_rescale_multiplier of 0.7.
    See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    """
    guidance_rescale_multiplier: float = 0

class TextConditioningRegions:
    def __init__(self, masks: torch.Tensor, ranges: list[Range]):
        # A binary mask indicating the regions of the image that the prompt should be applied to.
        # Shape: (1, num_prompts, height, width)
        # Dtype: torch.bool
        self.masks = masks

        # A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies
        # to. ranges[i] contains the embedding range for the i'th prompt / mask.
        self.ranges = ranges

        assert self.masks.shape[1] == len(self.ranges)


class TextConditioningData:
    def __init__(
        self,
        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
        self.uncond_regions = uncond_regions
        self.cond_regions = cond_regions
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
        # generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use a guidance_rescale_multiplier of
        # 0.7. See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)
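A small usage sketch of the invariant enforced by `TextConditioningRegions` above (the class and `Range` come from this diff; the tensor sizes are example values):

import torch

masks = torch.ones((1, 2, 64, 64), dtype=torch.bool)  # (1, num_prompts, h, w)
ranges = [Range(start=0, end=77), Range(start=77, end=154)]
regions = TextConditioningRegions(masks=masks, ranges=ranges)
# The constructor asserts masks.shape[1] == len(ranges): exactly one mask per embedding range.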
@@ -1,5 +1,4 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

import torch
@@ -8,11 +7,9 @@ from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND


@dataclass
class Range:
    start: int
    end: int
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    TextConditioningRegions,
)


class RegionalPromptData:
@@ -20,10 +17,9 @@ class RegionalPromptData:
        self._attn_masks_by_seq_len = attn_masks_by_seq_len

    @classmethod
    def from_masks_and_ranges(
    def from_regions(
        cls,
        masks: list[torch.Tensor],
        embedding_ranges: list[list[Range]],
        regions: list[TextConditioningRegions],
        key_seq_len: int,
        # TODO(ryand): Pass in a list of downscale factors?
        max_downscale_factor: int = 8,
@@ -31,14 +27,8 @@
        """Construct a `RegionalPromptData` object.

        Args:
            masks (list[torch.Tensor]): masks[i] contains the region masks for the i'th sample in the batch.
                The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
                prompt should be applied, and 0.0 elsewhere.

            embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
                prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
                encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].

            regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
                batch.
            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
                explicitly to be sure.
@@ -48,11 +38,11 @@
        # batch_attn_masks_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
        # length of s.
        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
        for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
        for batch_sample_regions in regions:
            batch_attn_masks_by_seq_len.append({})

            # Convert the bool masks to float masks so that max pooling can be applied.
            batch_masks = batch_masks.to(dtype=torch.float32)
            batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)

            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
            downscale_factor = 1
@@ -69,7 +59,7 @@
                # TODO(ryand): What device / dtype should this be?
                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))

                for prompt_idx, embedding_range in enumerate(batch_ranges):
                for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
                        :, prompt_idx, :, :
                    ]
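A sketch of the attention-mask semantics implemented in the loop above (illustrative values; not code from the diff): for each prompt, its spatial query mask is broadcast into the key columns covered by that prompt's embedding range.

import torch

query_seq_len, key_seq_len = 4096, 154  # e.g. a 64x64 latent and two 77-token prompts
ranges = [(0, 77), (77, 154)]
query_masks = torch.rand(2, query_seq_len)  # per-prompt query masks, flattened spatially

attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, (start, end) in enumerate(ranges):
    # Broadcast the (query_seq_len, 1) column mask across this prompt's key range.
    attn_mask[0, :, start:end] = query_masks[prompt_idx].unsqueeze(-1)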
@@ -5,19 +5,18 @@ from contextlib import contextmanager
from typing import Any, Callable, Optional, Union

import torch
import torchvision
from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ExtraConditioningInfo,
    IPAdapterConditioningInfo,
    SDXLConditioningInfo,
    Range,
    TextConditioningData,
    TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import RegionalPromptData

from .cross_attention_control import (
    CrossAttentionType,
@@ -36,143 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
]


class RegionalTextConditioningInfo:
    def __init__(
        self,
        text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
        masks: Optional[torch.Tensor] = None,
        embedding_ranges: Optional[list[Range]] = None,
    ):
        """Initialize a RegionalTextConditioningInfo.

        Args:
            text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
                after concatenating the embeddings for all regions.
            masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
            embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
        """
        self.text_conditioning = text_conditioning
        self.masks = masks
        self.embedding_ranges = embedding_ranges

        assert (self.masks is None) == (self.embedding_ranges is None)
        if self.masks is not None:
            assert self.masks.shape[1] == len(self.embedding_ranges)

    def has_region_masks(self):
        if self.masks is None:
            return False
        return any(mask is not None for mask in self.masks)

    def is_sdxl(self):
        return isinstance(self.text_conditioning, SDXLConditioningInfo)

    @classmethod
    def _preprocess_regional_prompt_mask(
        cls, mask: Optional[torch.Tensor], target_height: int, target_width: int
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target height and width.

        If mask is None, returns a mask of all ones with the target height and width.
        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.

        Returns:
            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
        """
        if mask is None:
            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)

        tf = torchvision.transforms.Resize(
            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
        mask = tf(mask)

        return mask

    @classmethod
    def from_text_conditioning_and_masks(
        cls,
        text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
        masks: Optional[list[Optional[torch.Tensor]]],
        latent_height: int,
        latent_width: int,
    ):
        if masks is None:
            masks = [None] * len(text_conditionings)
        assert len(text_conditionings) == len(masks)

        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo

        all_masks_are_none = all(mask is None for mask in masks)

        text_embedding = []
        pooled_embedding = None
        add_time_ids = None
        processed_masks = []
        cur_text_embedding_len = 0
        embedding_ranges: list[Range] = []

        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
            assert (
                text_embedding_info.extra_conditioning is None
                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
            )

            if is_sdxl:
                # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
                # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
                # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
                # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
                # pretty major breaking change to a popular node, so for now we use this hack.
                #
                # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
                # most likely to be a global prompt.
                if pooled_embedding is None:
                    pooled_embedding = text_embedding_info.pooled_embeds
                if add_time_ids is None:
                    add_time_ids = text_embedding_info.add_time_ids

            text_embedding.append(text_embedding_info.embeds)
            embedding_ranges.append(
                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
            )
            cur_text_embedding_len += text_embedding_info.embeds.shape[1]

            if not all_masks_are_none:
                processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))

        text_embedding = torch.cat(text_embedding, dim=1)
        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

        if not all_masks_are_none:
            processed_masks = torch.cat(processed_masks, dim=1)
        else:
            processed_masks = None
            embedding_ranges = None

        if is_sdxl:
            return cls(
                text_conditioning=SDXLConditioningInfo(
                    embeds=text_embedding,
                    extra_conditioning=None,
                    pooled_embeds=pooled_embedding,
                    add_time_ids=add_time_ids,
                ),
                masks=processed_masks,
                embedding_ranges=embedding_ranges,
            )
        return cls(
            text_conditioning=BasicConditioningInfo(
                embeds=text_embedding,
                extra_conditioning=None,
            ),
            masks=processed_masks,
            embedding_ranges=embedding_ranges,
        )


class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
@@ -233,10 +95,6 @@
        conditioning_data: TextConditioningData,
    ):
        down_block_res_samples, mid_block_res_sample = None, None
        # HACK(ryand): Currently, we just take the first text embedding if there's more than one. We should probably
        # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
        uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
        cond_text_embeddings = conditioning_data.cond_text_embeddings[0]

        # control_data should be type List[ControlNetData]
        # this loop covers both ControlNet (one ControlNetData in list)
@@ -267,25 +125,30 @@
            added_cond_kwargs = None

            if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
                if type(cond_text_embeddings) is SDXLConditioningInfo:
                if conditioning_data.is_sdxl():
                    added_cond_kwargs = {
                        "text_embeds": cond_text_embeddings.pooled_embeds,
                        "time_ids": cond_text_embeddings.add_time_ids,
                        "text_embeds": conditioning_data.cond_text.pooled_embeds,
                        "time_ids": conditioning_data.cond_text.add_time_ids,
                    }
                encoder_hidden_states = cond_text_embeddings.embeds
                encoder_hidden_states = conditioning_data.cond_text.embeds
                encoder_attention_mask = None
            else:
                if type(cond_text_embeddings) is SDXLConditioningInfo:
                if conditioning_data.is_sdxl():
                    added_cond_kwargs = {
                        "text_embeds": torch.cat(
                            [uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
                            [
                                conditioning_data.uncond_text.pooled_embeds,
                                conditioning_data.cond_text.pooled_embeds,
                            ],
                            dim=0,
                        ),
                        "time_ids": torch.cat(
                            [uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
                            [conditioning_data.uncond_text.add_time_ids, conditioning_data.cond_text.add_time_ids],
                            dim=0,
                        ),
                    }
                (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
                    uncond_text_embeddings.embeds, cond_text_embeddings.embeds
                    conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
                )
            if isinstance(control_datum.weight, list):
                # if controlnet has multiple weights, use the weight for the current step
@@ -440,52 +303,7 @@
        x_twice = torch.cat([x] * 2)
        sigma_twice = torch.cat([sigma] * 2)

        # TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
        # denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
        # to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
        # the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
        # is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
        # stack.
        cross_attention_kwargs = None
        _, _, h, w = x.shape
        cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
            text_conditionings=conditioning_data.cond_text_embeddings,
            masks=conditioning_data.cond_text_embedding_masks,
            latent_height=h,
            latent_width=w,
        )
        uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
            text_conditionings=conditioning_data.uncond_text_embeddings,
            masks=conditioning_data.uncond_text_embedding_masks,
            latent_height=h,
            latent_width=w,
        )

        if cond_text.has_region_masks() or uncond_text.has_region_masks():
            masks = []
            embedding_ranges = []
            for c in [uncond_text, cond_text]:
                if c.has_region_masks():
                    masks.append(c.masks)
                    embedding_ranges.append(c.embedding_ranges)
                else:
                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                    masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
                    embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])

            # The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
            # embeddings will be padded to match this length.
            key_seq_len = 0
            for c in [uncond_text, cond_text]:
                _, seq_len, _ = c.text_conditioning.embeds.shape
                if seq_len > key_seq_len:
                    key_seq_len = seq_len

            cross_attention_kwargs = {
                "regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
                    masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
                )
            }

        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
        if ip_adapter_conditioning is not None:
@@ -499,20 +317,45 @@
                ]
            }

        uncond_text = conditioning_data.uncond_text
        cond_text = conditioning_data.cond_text

        added_cond_kwargs = None
        if cond_text.is_sdxl():
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": torch.cat(
                    [uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
                ),
                "time_ids": torch.cat(
                    [uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
                ),
                "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds], dim=0),
                "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids], dim=0),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
            uncond_text.embeds, cond_text.embeds
        )

        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
            # TODO(ryand): We currently call from_regions(...) for every denoising step. The text conditionings and
            # masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
            regions = []
            for c, r in [
                (conditioning_data.uncond_text, conditioning_data.uncond_regions),
                (conditioning_data.cond_text, conditioning_data.cond_regions),
            ]:
                if r is None:
                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                    _, _, h, w = x.shape
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=torch.bool),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                regions.append(r)

            _, key_seq_len, _ = both_conditionings.shape
            cross_attention_kwargs = {
                "regional_prompt_data": RegionalPromptData.from_regions(regions=regions, key_seq_len=key_seq_len)
            }

        both_results = self.model_forward_callback(
            x_twice,
            sigma_twice,
@@ -542,9 +385,6 @@
        slower execution speed.
        """

        assert len(conditioning_data.cond_text_embeddings) == 1
        text_embeddings = conditioning_data.cond_text_embeddings[0]

        # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
        # and T2I-Adapter residuals into two chunks.
        uncond_down_block, cond_down_block = None, None
@@ -602,18 +442,17 @@

        # Prepare SDXL conditioning kwargs for the unconditioned pass.
        added_cond_kwargs = None
        is_sdxl = type(text_embeddings) is SDXLConditioningInfo
        if is_sdxl:
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": conditioning_data.uncond_text_embeddings.pooled_embeds,
                "time_ids": conditioning_data.uncond_text_embeddings.add_time_ids,
                "text_embeds": conditioning_data.uncond_text.pooled_embeds,
                "time_ids": conditioning_data.uncond_text.add_time_ids,
            }

        # Run unconditioned UNet denoising (i.e. negative prompt).
        unconditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            conditioning_data.uncond_text_embeddings.embeds,
            conditioning_data.uncond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=uncond_down_block,
            mid_block_additional_residual=uncond_mid_block,
@@ -644,17 +483,17 @@

        # Prepare SDXL conditioning kwargs for the conditioned pass.
        added_cond_kwargs = None
        if is_sdxl:
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": text_embeddings.pooled_embeds,
                "time_ids": text_embeddings.add_time_ids,
                "text_embeds": conditioning_data.cond_text.pooled_embeds,
                "time_ids": conditioning_data.cond_text.add_time_ids,
            }

        # Run conditioned UNet denoising (i.e. positive prompt).
        conditioned_next_x = self.model_forward_callback(
            x,
            sigma,
            text_embeddings.embeds,
            conditioning_data.cond_text.embeds,
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=cond_down_block,
            mid_block_additional_residual=cond_mid_block,
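For context, the paired uncond/cond predictions produced by `model_forward_callback` above feed the standard classifier-free guidance combination. The helper below is an illustrative sketch of that formula (from the papers cited in the conditioning data comments), not code from this diff:

import torch

def apply_cfg(uncond_next_x: torch.Tensor, cond_next_x: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # guidance_scale > 1 pushes the prediction toward the conditioned output.
    return uncond_next_x + guidance_scale * (cond_next_x - uncond_next_x)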