wip
commit d183aa823c (parent d582203c62)
@@ -57,10 +57,10 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    IPAdapterConditioningInfo,
    IPAdapterData,
-   Range,
+   SDRegionalTextConditioning,
    SDXLConditioningInfo,
+   SDXLRegionalTextConditioning,
    TextConditioningData,
-   TextConditioningRegions,
)
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings
@@ -408,19 +408,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
        resized_mask = tf(mask)
        return resized_mask

-   def _concat_regional_text_embeddings(
+   def _prepare_regional_text_embeddings(
        self,
        text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
-       masks: Optional[list[Optional[torch.Tensor]]],
+       masks: list[Optional[torch.Tensor]],
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
-   ) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
+   ) -> Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning]:
        """Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
-       if masks is None:
-           masks = [None] * len(text_conditionings)
        assert len(text_conditionings) == len(masks)

        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo

        all_masks_are_none = all(mask is None for mask in masks)
@@ -428,9 +424,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
        text_embedding = []
        pooled_embedding = None
        add_time_ids = None
-       cur_text_embedding_len = 0
        processed_masks = []
-       embedding_ranges = []

        for prompt_idx, text_embedding_info in enumerate(text_conditionings):
            mask = masks[prompt_idx]
@@ -453,32 +447,21 @@ class DenoiseLatentsInvocation(BaseInvocation):

            text_embedding.append(text_embedding_info.embeds)
            if not all_masks_are_none:
-               embedding_ranges.append(
-                   Range(
-                       start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
-                   )
-               )
                processed_masks.append(
                    self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
                )

-           cur_text_embedding_len += text_embedding_info.embeds.shape[1]
-
-       text_embedding = torch.cat(text_embedding, dim=1)
-       assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
-
-       regions = None
-       if not all_masks_are_none:
-           regions = TextConditioningRegions(
-               masks=torch.cat(processed_masks, dim=1),
-               ranges=embedding_ranges,
-           )
-
        if is_sdxl:
-           return SDXLConditioningInfo(
-               embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
-           ), regions
-       return BasicConditioningInfo(embeds=text_embedding), regions
+           return SDXLRegionalTextConditioning(
+               pooled_embeds=pooled_embedding,
+               add_time_ids=add_time_ids,
+               text_embeds=text_embedding,
+               masks=None if all_masks_are_none else processed_masks,
+           )
+       return SDRegionalTextConditioning(
+           text_embeds=text_embedding,
+           masks=None if all_masks_are_none else processed_masks,
+       )

    def get_conditioning_data(
        self,
@@ -502,14 +485,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
            uncond_list, context, unet.device, unet.dtype
        )

-       cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
+       cond_text = self._prepare_regional_text_embeddings(
            text_conditionings=cond_text_embeddings,
            masks=cond_text_embedding_masks,
            latent_height=latent_height,
            latent_width=latent_width,
            dtype=unet.dtype,
        )
-       uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
+       uncond_text = self._prepare_regional_text_embeddings(
            text_conditionings=uncond_text_embeddings,
            masks=uncond_text_embedding_masks,
            latent_height=latent_height,
@@ -518,10 +501,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
        )

        conditioning_data = TextConditioningData(
-           uncond_text=uncond_text_embedding,
-           cond_text=cond_text_embedding,
-           uncond_regions=uncond_regions,
-           cond_regions=cond_regions,
+           uncond_text=uncond_text,
+           cond_text=cond_text,
            guidance_scale=self.cfg_scale,
            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
        )
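Taken together, the hunks above change the method's contract: instead of returning a (conditioning, regions) tuple, it now returns one object that carries the per-prompt embedding list and the optional mask list side by side. A minimal sketch of that contract, using the SDRegionalTextConditioning class introduced later in this diff (the tensor shapes are illustrative placeholders, not values from the commit):

    import torch

    # Two hypothetical prompts with their region masks.
    text_embedding = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # (batch, seq_len, dim) each
    processed_masks = [torch.ones(1, 1, 64, 64), torch.zeros(1, 1, 64, 64)]

    cond = SDRegionalTextConditioning(text_embeds=text_embedding, masks=processed_masks)
    assert cond.uses_regional_prompts()  # multiple prompts -> regional path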
@@ -386,7 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):

        use_ip_adapter = ip_adapter_data is not None
        use_regional_prompting = (
-           conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
+           conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts()
        )
        unet_attention_patcher = None
        self.use_ip_adapter = use_ip_adapter
@@ -95,20 +95,50 @@ class TextConditioningRegions:
        assert self.masks.shape[1] == len(self.ranges)


+class SDRegionalTextConditioning:
+   def __init__(self, text_embeds: list[torch.Tensor], masks: Optional[list[torch.Tensor]]):
+       if masks is not None:
+           assert len(text_embeds) == len(masks)
+
+       # A list of text embeddings. text_embeds[i] contains the text embeddings for the i'th prompt.
+       self.text_embeds = text_embeds
+       # A list of masks indicating the regions of the image that the prompts should be applied to. masks[i] contains
+       # the mask for the i'th prompt. Each mask has shape (1, height, width).
+       self.masks = masks
+
+   def uses_regional_prompts(self):
+       # If there is more than one prompt, we treat this as regional prompting, even if there are no masks, because
+       # the regional prompting logic is used to combine the information from multiple prompts.
+       return len(self.text_embeds) > 1 or self.masks is not None
+
+
+class SDXLRegionalTextConditioning(SDRegionalTextConditioning):
+   def __init__(
+       self,
+       pooled_embeds: torch.Tensor,
+       add_time_ids: torch.Tensor,
+       text_embeds: list[torch.Tensor],
+       masks: Optional[list[torch.Tensor]],
+   ):
+       super().__init__(text_embeds, masks)
+
+       # Pooled embeddings for the global prompt.
+       self.pooled_embeds = pooled_embeds
+       # Additional global conditioning inputs for SDXL. The name "time_ids" comes from diffusers, and is a bit of a
+       # misnomer. This Tensor contains original_size, crop_coords, and target_size conditioning.
+       self.add_time_ids = add_time_ids
+
+
class TextConditioningData:
    def __init__(
        self,
-       uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
-       cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
-       uncond_regions: Optional[TextConditioningRegions],
-       cond_regions: Optional[TextConditioningRegions],
+       uncond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
+       cond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
-       self.uncond_regions = uncond_regions
-       self.cond_regions = cond_regions
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
@@ -119,5 +149,7 @@ class TextConditioningData:
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
-       assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
-       return isinstance(self.cond_text, SDXLConditioningInfo)
+       assert isinstance(self.uncond_text, SDXLRegionalTextConditioning) == isinstance(
+           self.cond_text, SDXLRegionalTextConditioning
+       )
+       return isinstance(self.cond_text, SDXLRegionalTextConditioning)
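The `use_regional_prompting` gate in the pipeline hunk above now delegates to `uses_regional_prompts()`. A small sketch of the rule this method encodes (shapes are placeholders, not values from the commit):

    import torch

    # One prompt, no masks: plain conditioning, the regional machinery stays off.
    plain = SDRegionalTextConditioning(text_embeds=[torch.randn(1, 77, 768)], masks=None)
    assert not plain.uses_regional_prompts()

    # More than one prompt triggers the regional path even without masks, because
    # the regional logic is what combines the information from multiple prompts.
    multi = SDRegionalTextConditioning(
        text_embeds=[torch.randn(1, 77, 768), torch.randn(1, 77, 768)],
        masks=None,
    )
    assert multi.uses_regional_prompts()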
@@ -63,6 +63,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
        # If true, we are doing cross-attention, if false we are doing self-attention.
        is_cross_attention = encoder_hidden_states is not None

+       _, query_seq_len, _ = hidden_states.shape
+       if regional_prompt_data is not None and is_cross_attention:
+           assert percent_through is not None
+           prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
+           encoder_hidden_states = regional_prompt_data.text_embeds
+
        # Start unmodified block from AttnProcessor2_0.
        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
        residual = hidden_states
@@ -81,18 +87,26 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
        # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        # End unmodified block from AttnProcessor2_0.

-       _, query_seq_len, _ = hidden_states.shape
-       # Handle regional prompt attention masks.
-       if regional_prompt_data is not None and is_cross_attention:
-           assert percent_through is not None
-           prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
-               query_seq_len=query_seq_len, key_seq_len=sequence_length
-           )
+       # Current:
+       # - Run attention once, with masking to control which tokens each pixel is *allowed* to pay attention to.
+       # New:
+       # - Run attention on each prompt separately (no masking).
+       # - Combine the results with a weighted sum.

-           if attention_mask is None:
-               attention_mask = prompt_region_attention_mask
-           else:
-               attention_mask = prompt_region_attention_mask + attention_mask
+       # _, query_seq_len, _ = hidden_states.shape
+       # Handle regional prompt attention masks.
+       # if regional_prompt_data is not None and is_cross_attention:
+       #     assert percent_through is not None
+       #     prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
+
+       #     prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+       #         query_seq_len=query_seq_len, key_seq_len=sequence_length
+       #     )
+
+       #     if attention_mask is None:
+       #         attention_mask = prompt_region_attention_mask
+       #     else:
+       #         attention_mask = prompt_region_attention_mask + attention_mask

        # Start unmodified block from AttnProcessor2_0.
        # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
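The "Current/New" notes above describe a direction that this wip commit does not implement yet. A minimal sketch of what the proposed weighted-sum combination could look like, under assumptions of my own (per-prompt key/value pairs and per-pixel weight masks; none of these names come from the diff):

    import torch
    import torch.nn.functional as F

    def weighted_sum_cross_attention(
        query: torch.Tensor,  # (batch, query_seq_len, dim)
        per_prompt_kv: list[tuple[torch.Tensor, torch.Tensor]],  # one (key, value) pair per prompt
        prompt_masks: torch.Tensor,  # (num_prompts, query_seq_len), per-pixel weights
    ) -> torch.Tensor:
        # Run attention once per prompt with no masking, then blend the outputs
        # spatially using the normalized region masks as weights.
        outputs = [F.scaled_dot_product_attention(query, k, v) for k, v in per_prompt_kv]
        stacked = torch.stack(outputs, dim=0)  # (num_prompts, batch, query_seq_len, dim)

        weights = prompt_masks / prompt_masks.sum(dim=0, keepdim=True).clamp(min=1e-6)
        return (stacked * weights[:, None, :, None]).sum(dim=0)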
@@ -1,9 +1,7 @@
import torch
import torch.nn.functional as F

-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-   TextConditioningRegions,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningRegions


class RegionalPromptData:
@@ -11,31 +9,71 @@ class RegionalPromptData:

    def __init__(
        self,
-       regions: list[TextConditioningRegions],
+       text_embeds: list[list[torch.Tensor]],
+       masks: list[list[torch.Tensor]],
        device: torch.device,
        dtype: torch.dtype,
        max_downscale_factor: int = 8,
    ):
        """Initialize a `RegionalPromptData` object.

        Args:
-           regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
-               batch.
+           TODO(ryand): Update these docs.
+
+           text_embeds (list[list[torch.Tensor]]): The text prompt embeddings. text_embeds[b][i] contains the embedding
+               for prompt i to be applied to batch image b.
+           masks (list[list[torch.Tensor]]): The masks indicating the spatial regions of the image that each prompt
+               applies to. masks[b][i] contains the mask for text_embeds[b][i].
+
            device (torch.device): The device to use for the attention masks.
            dtype (torch.dtype): The data type to use for the attention masks.
            max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
                in steps of 2x.
        """
-       self._regions = regions
+       assert len(text_embeds) == len(masks)
+       for text_embeds_batch, masks_batch in zip(text_embeds, masks, strict=True):
+           assert len(text_embeds_batch) == len(masks_batch)
+
+       self.prompt_count_by_batch_element = [len(text_embeds_batch) for text_embeds_batch in text_embeds]
+
+       # Flatten and concat text_embeds.
+       text_embeds_flat_list: list[torch.Tensor] = []
+       for text_embeds_batch in text_embeds:
+           text_embeds_flat_list.extend(text_embeds_batch)
+       # TODO(ryand): Or stack?
+       # TODO(ryand): Text embeds might not all be the same size (if there were long prompts).
+       self.text_embeds = torch.cat(text_embeds_flat_list, dim=0)
+
+       # Flatten and concat masks.
+       masks_flat_list = []
+       for mask_batch in masks:
+           masks_flat_list.extend(mask_batch)
+       self._masks = torch.cat(masks_flat_list, dim=0)
+
        self._device = device
        self._dtype = dtype
-       # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
-       # sequence length of s.
-       self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
-           regions, max_downscale_factor
-       )
        self._negative_cross_attn_mask_score = -10000.0

-   def _prepare_spatial_masks(
+   def get_masks(self, query_seq_len: int):
+       _, h, w = self._masks.shape
+
+       # Determine the downscaling factor for the given query sequence length.
+       max_downscale_factor = 8
+       downscale_factor = 1
+       while downscale_factor <= max_downscale_factor:
+           if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
+               break
+           downscale_factor *= 2
+
+       if query_seq_len != (h // downscale_factor) * (w // downscale_factor):
+           raise ValueError(f"Failed to find a mask downsampling factor for query sequence length: {query_seq_len}")
+
+       target_h = h // downscale_factor
+       target_w = w // downscale_factor
+       mask_downscaled = torch.nn.functional.interpolate(self._masks, size=(target_h, target_w), mode="nearest")
+
+       return mask_downscaled
+
+   def _prepare_spatial_masks_old(
        self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
    ) -> list[dict[int, torch.Tensor]]:
        """Prepare the spatial masks for all downscaling factors."""
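The new get_masks() infers the UNet level from the query sequence length: cross-attention at downscale factor k sees (h // k) * (w // k) query tokens, so the method probes factors 1, 2, 4, 8 until the flattened size matches, then nearest-neighbor downsamples the masks to that resolution. The search logic, restated standalone (the helper name is mine, not from the diff):

    def find_downscale_factor(h: int, w: int, query_seq_len: int, max_factor: int = 8) -> int:
        factor = 1
        while factor <= max_factor:
            if query_seq_len == (h // factor) * (w // factor):
                return factor
            factor *= 2
        raise ValueError(f"No downscale factor matches query_seq_len={query_seq_len}")

    # e.g. for a 64x64 latent: 4096 query tokens -> factor 1; 1024 -> factor 2 (32 * 32).
    assert find_downscale_factor(64, 64, 64 * 64) == 1
    assert find_downscale_factor(64, 64, 32 * 32) == 2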
@@ -7,12 +7,7 @@
import torch
from typing_extensions import TypeAlias

from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
-   IPAdapterData,
-   Range,
-   TextConditioningData,
-   TextConditioningRegions,
-)
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@@ -312,33 +307,35 @@ class InvokeAIDiffuserComponent:
            ),
        }

-       if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
+       if conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts():
            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
-           regions = []
-           for c, r in [
-               (conditioning_data.uncond_text, conditioning_data.uncond_regions),
-               (conditioning_data.cond_text, conditioning_data.cond_regions),
-           ]:
-               if r is None:
-                   # Create a dummy mask and range for text conditioning that doesn't have region masks.
+           masks: list[list[torch.Tensor]] = []
+           for text_conditioning in [conditioning_data.uncond_text, conditioning_data.cond_text]:
+               if text_conditioning.masks is None:
+                   # Create a dummy mask for text conditioning that doesn't have region masks.
                    _, _, h, w = x.shape
-                   r = TextConditioningRegions(
-                       masks=torch.ones((1, 1, h, w), dtype=x.dtype),
-                       ranges=[Range(start=0, end=c.embeds.shape[1])],
-                   )
-               regions.append(r)
+                   masks.append([torch.ones((1, 1, h, w), dtype=x.dtype)] * len(text_conditioning.text_embeds))
+               else:
+                   masks.append(text_conditioning.masks)

            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
-               regions=regions, device=x.device, dtype=x.dtype
+               text_embeds=[conditioning_data.uncond_text.text_embeds, conditioning_data.cond_text.text_embeds],
+               masks=masks,
+               device=x.device,
+               dtype=x.dtype,
            )
            cross_attention_kwargs["percent_through"] = step_index / total_step_count

+       # Note: We pass in the *first* text_embeds entry for both unconditioned and conditioned text embeds. This is the
+       # desired behaviour under 'normal' conditions when there is a single text prompt. In cases where we are doing
+       # regional prompting with multiple prompts, this input will be ignored altogether and the prompt information
+       # will be passed via the RegionalPromptData object.
        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-           conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
+           conditioning_data.uncond_text.text_embeds[0], conditioning_data.cond_text.text_embeds[0]
        )
        both_results = self.model_forward_callback(
            x_twice,
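To summarize the data flow in this last hunk: the diffuser component nests prompt data as text_embeds[b][i] and masks[b][i], where b indexes the (uncond, cond) batch pair and i indexes the prompts for that pass; prompts without masks get a dummy full-image mask. A self-contained sketch of that structure (shapes illustrative only, not values from the commit):

    import torch

    h, w = 64, 64
    uncond_embeds = [torch.randn(1, 77, 768)]  # single negative prompt
    cond_embeds = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]  # two regional prompts

    text_embeds = [uncond_embeds, cond_embeds]
    masks = [
        [torch.ones((1, 1, h, w))] * len(uncond_embeds),  # dummy full-image mask
        [torch.ones((1, 1, h, w)), torch.zeros((1, 1, h, w))],  # per-region masks
    ]
    assert all(len(e) == len(m) for e, m in zip(text_embeds, masks, strict=True))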