This commit is contained in:
Ryan Dick 2024-04-20 17:09:41 -04:00
parent d582203c62
commit d183aa823c
6 changed files with 155 additions and 93 deletions

View File

@ -57,10 +57,10 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
IPAdapterConditioningInfo, IPAdapterConditioningInfo,
IPAdapterData, IPAdapterData,
Range, SDRegionalTextConditioning,
SDXLConditioningInfo, SDXLConditioningInfo,
SDXLRegionalTextConditioning,
TextConditioningData, TextConditioningData,
TextConditioningRegions,
) )
from invokeai.backend.util.mask import to_standard_float_mask from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
@ -408,19 +408,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
resized_mask = tf(mask) resized_mask = tf(mask)
return resized_mask return resized_mask
def _concat_regional_text_embeddings( def _prepare_regional_text_embeddings(
self, self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]], masks: list[Optional[torch.Tensor]],
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
dtype: torch.dtype, dtype: torch.dtype,
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]: ) -> Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning]:
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly.""" """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks) all_masks_are_none = all(mask is None for mask in masks)
@ -428,9 +424,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embedding = [] text_embedding = []
pooled_embedding = None pooled_embedding = None
add_time_ids = None add_time_ids = None
cur_text_embedding_len = 0
processed_masks = [] processed_masks = []
embedding_ranges = []
for prompt_idx, text_embedding_info in enumerate(text_conditionings): for prompt_idx, text_embedding_info in enumerate(text_conditionings):
mask = masks[prompt_idx] mask = masks[prompt_idx]
@ -453,32 +447,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
text_embedding.append(text_embedding_info.embeds) text_embedding.append(text_embedding_info.embeds)
if not all_masks_are_none: if not all_masks_are_none:
embedding_ranges.append(
Range(
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
)
)
processed_masks.append( processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
) )
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
regions = None
if not all_masks_are_none:
regions = TextConditioningRegions(
masks=torch.cat(processed_masks, dim=1),
ranges=embedding_ranges,
)
if is_sdxl: if is_sdxl:
return SDXLConditioningInfo( return SDXLRegionalTextConditioning(
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids pooled_embeds=pooled_embedding,
), regions add_time_ids=add_time_ids,
return BasicConditioningInfo(embeds=text_embedding), regions text_embeds=text_embedding,
masks=None if all_masks_are_none else processed_masks,
)
return SDRegionalTextConditioning(
text_embeds=text_embedding,
masks=None if all_masks_are_none else processed_masks,
)
def get_conditioning_data( def get_conditioning_data(
self, self,
@ -502,14 +485,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
uncond_list, context, unet.device, unet.dtype uncond_list, context, unet.device, unet.dtype
) )
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings( cond_text = self._prepare_regional_text_embeddings(
text_conditionings=cond_text_embeddings, text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks, masks=cond_text_embedding_masks,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
dtype=unet.dtype, dtype=unet.dtype,
) )
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( uncond_text = self._prepare_regional_text_embeddings(
text_conditionings=uncond_text_embeddings, text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks, masks=uncond_text_embedding_masks,
latent_height=latent_height, latent_height=latent_height,
@ -518,10 +501,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
conditioning_data = TextConditioningData( conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding, uncond_text=uncond_text,
cond_text=cond_text_embedding, cond_text=cond_text,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=self.cfg_rescale_multiplier,
) )

View File

@ -386,7 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
use_ip_adapter = ip_adapter_data is not None use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = ( use_regional_prompting = (
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts()
) )
unet_attention_patcher = None unet_attention_patcher = None
self.use_ip_adapter = use_ip_adapter self.use_ip_adapter = use_ip_adapter

View File

@ -95,20 +95,50 @@ class TextConditioningRegions:
assert self.masks.shape[1] == len(self.ranges) assert self.masks.shape[1] == len(self.ranges)
class SDRegionalTextConditioning:
def __init__(self, text_embeds: list[torch.Tensor], masks: Optional[list[torch.Tensor]]):
if masks is not None:
assert len(text_embeds) == len(masks)
# A list of text embeddings. text_embeds[i] contains the text embeddings for the i'th prompt.
self.text_embeds = text_embeds
# A list of masks indicating the regions of the image that the prompts should be applied to. masks[i] contains
# the mask for the i'th prompt. Each mask has shape (1, height, width).
self.masks = masks
def uses_regional_prompts(self):
# If there is more than one prompt, we treat this as regional prompting, even if there are no masks, because
# the regional prompting logic is used to combine the information from multiple prompts.
return len(self.text_embeds) > 1 or self.masks is not None
class SDXLRegionalTextConditioning(SDRegionalTextConditioning):
def __init__(
self,
pooled_embeds: torch.Tensor,
add_time_ids: torch.Tensor,
text_embeds: list[torch.Tensor],
masks: Optional[list[torch.Tensor]],
):
super().__init__(text_embeds, masks)
# Pooled embeddings for the global prompt.
self.pooled_embeds = pooled_embeds
# Additional global conditioning inputs for SDXL. The name "time_ids" comes from diffusers, and is a bit of a
# misnomer. This Tensor contains original_size, crop_coords, and target_size conditioning.
self.add_time_ids = add_time_ids
class TextConditioningData: class TextConditioningData:
def __init__( def __init__(
self, self,
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo], uncond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo], cond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]], guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0, guidance_rescale_multiplier: float = 0,
): ):
self.uncond_text = uncond_text self.uncond_text = uncond_text
self.cond_text = cond_text self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
@ -119,5 +149,7 @@ class TextConditioningData:
self.guidance_rescale_multiplier = guidance_rescale_multiplier self.guidance_rescale_multiplier = guidance_rescale_multiplier
def is_sdxl(self): def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) assert isinstance(self.uncond_text, SDXLRegionalTextConditioning) == isinstance(
return isinstance(self.cond_text, SDXLConditioningInfo) self.cond_text, SDXLRegionalTextConditioning
)
return isinstance(self.cond_text, SDXLRegionalTextConditioning)

View File

@ -63,6 +63,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# If true, we are doing cross-attention, if false we are doing self-attention. # If true, we are doing cross-attention, if false we are doing self-attention.
is_cross_attention = encoder_hidden_states is not None is_cross_attention = encoder_hidden_states is not None
_, query_seq_len, _ = hidden_states.shape
if regional_prompt_data is not None and is_cross_attention:
assert percent_through is not None
prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
encoder_hidden_states = regional_prompt_data.text_embeds
# Start unmodified block from AttnProcessor2_0. # Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states residual = hidden_states
@ -81,18 +87,26 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0. # End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape # Current:
# Handle regional prompt attention masks. # - Run attention once, with masking to control which tokens each pixel is *allowed* to pay attention to.
if regional_prompt_data is not None and is_cross_attention: # New
assert percent_through is not None # - Run attention on each prompt separately. (no masking)
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask( # - Combine the results with a weighted sum.
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
if attention_mask is None: # _, query_seq_len, _ = hidden_states.shape
attention_mask = prompt_region_attention_mask # Handle regional prompt attention masks.
else: # if regional_prompt_data is not None and is_cross_attention:
attention_mask = prompt_region_attention_mask + attention_mask # assert percent_through is not None
# prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
# prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
# query_seq_len=query_seq_len, key_seq_len=sequence_length
# )
# if attention_mask is None:
# attention_mask = prompt_region_attention_mask
# else:
# attention_mask = prompt_region_attention_mask + attention_mask
# Start unmodified block from AttnProcessor2_0. # Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv

View File

@ -1,9 +1,7 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningRegions
TextConditioningRegions,
)
class RegionalPromptData: class RegionalPromptData:
@ -11,31 +9,71 @@ class RegionalPromptData:
def __init__( def __init__(
self, self,
regions: list[TextConditioningRegions], text_embeds: list[list[torch.Tensor]],
masks: list[list[torch.Tensor]],
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
max_downscale_factor: int = 8, max_downscale_factor: int = 8,
): ):
"""Initialize a `RegionalPromptData` object. """Initialize a `RegionalPromptData` object.
Args: Args:
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the TODO(ryand): Update these docs.
batch.
text_embeds (list[list[torch.Tensor]]): The text prompt embeddings. text_embeds[b][i] contains the embedding
for prompt i to be applied to batch image b.
masks (list[list[torch.Tensor]]): The masks indicating the spatial regions of the image that each prompt
applies to. masks[b][i] contains the mask for text_embeds[b][i].
device (torch.device): The device to use for the attention masks. device (torch.device): The device to use for the attention masks.
dtype (torch.dtype): The data type to use for the attention masks. dtype (torch.dtype): The data type to use for the attention masks.
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
in steps of 2x. in steps of 2x.
""" """
self._regions = regions
assert len(text_embeds) == len(masks)
for text_embeds_batch, masks_batch in zip(text_embeds, masks, strict=True):
assert len(text_embeds_batch) == len(masks_batch)
self.prompt_count_by_batch_element = [len(text_embeds_batch) for text_embeds_batch in text_embeds]
# Flattenand concat text_embeds.
text_embeds_flat_list: list[torch.Tensor] = []
for text_embeds_batch in text_embeds:
text_embeds_flat_list.extend(text_embeds_batch)
# TODO(ryand): Or stack?
# TODO(ryand): Text embeds might not all be the same size (if there were long prompts).
self.text_embeds = torch.cat(text_embeds_flat_list, dim=0)
# Flatten and concat masks.
masks_flat_list = []
for mask_batch in masks:
masks_flat_list.extend(mask_batch)
self._masks = torch.cat(masks_flat_list, dim=0)
self._device = device self._device = device
self._dtype = dtype self._dtype = dtype
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
# sequence length of s.
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor
)
self._negative_cross_attn_mask_score = -10000.0
def _prepare_spatial_masks( def get_masks(self, query_seq_len: int):
_, h, w = self._masks.shape
# Determine the downscaling factor for the given query sequence length.
max_downscale_factor = 8
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
break
downscale_factor *= 2
if query_seq_len != (h // downscale_factor) * (w // downscale_factor):
raise ValueError(f"Failed to find a mask downsampling factor for query sequence length: {query_seq_len}")
target_h = h // downscale_factor
target_w = w // downscale_factor
mask_downscaled = torch.nn.functional.interpolate(self._masks, size=(target_h, target_w), mode="nearest")
return mask_downscaled
def _prepare_spatial_masks_old(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8 self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]: ) -> list[dict[int, torch.Tensor]]:
"""Prepare the spatial masks for all downscaling factors.""" """Prepare the spatial masks for all downscaling factors."""

View File

@ -7,12 +7,7 @@ import torch
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
IPAdapterData,
Range,
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@ -312,33 +307,35 @@ class InvokeAIDiffuserComponent:
), ),
} }
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: if conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts():
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack. # awkward to handle both standard conditioning and sequential conditioning further up the stack.
regions = [] masks: list[list[torch.Tensor]] = []
for c, r in [ for text_conditioning in [conditioning_data.uncond_text, conditioning_data.cond_text]:
(conditioning_data.uncond_text, conditioning_data.uncond_regions), if text_conditioning.masks is None:
(conditioning_data.cond_text, conditioning_data.cond_regions), # Create a dummy mask for text conditioning that doesn't have region masks.
]:
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
_, _, h, w = x.shape _, _, h, w = x.shape
r = TextConditioningRegions( masks.append([torch.ones((1, 1, h, w), dtype=x.dtype)] * len(text_conditioning.text_embeds))
masks=torch.ones((1, 1, h, w), dtype=x.dtype), else:
ranges=[Range(start=0, end=c.embeds.shape[1])], masks.append(text_conditioning.masks)
)
regions.append(r)
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
regions=regions, device=x.device, dtype=x.dtype text_embeds=[conditioning_data.uncond_text.text_embeds, conditioning_data.cond_text.text_embeds],
masks=masks,
device=x.device,
dtype=x.dtype,
) )
cross_attention_kwargs["percent_through"] = step_index / total_step_count cross_attention_kwargs["percent_through"] = step_index / total_step_count
# Note: We pass in the *first* text_embeds entry for both unconditioned and conditioned text embeds. This is the
# desired behaviour under 'normal' conditions when there is a single text prompt. In cases where we are doing
# regional prompting with multiple prompts, this input will be ignored altogether and the prompt information
# will be passed via the RegionalPromptData object.
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds conditioning_data.uncond_text.text_embeds[0], conditioning_data.cond_text.text_embeds[0]
) )
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, x_twice,