Route masks into the RegionalPromptAttnProcessor2_0 processors.

Ryan Dick 2024-02-16 19:35:24 -05:00
parent 878bbc3527
commit 2d5d370f38
5 changed files with 94 additions and 23 deletions


@@ -33,7 +33,7 @@ class AddConditioningMaskInvocation(BaseInvocation):
 def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
     """Convert a PIL image to a uint8 mask tensor."""
     np_image = np.array(image)
-    torch_image = torch.from_numpy(np_image[0, :, :])
+    torch_image = torch.from_numpy(np_image[:, :, 0])
     mask = torch_image >= 128
     return mask.to(dtype=torch.uint8)
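
The one-line change fixes a channel-indexing bug: np.array(image) returns an array in (height, width, channels) order, so np_image[0, :, :] sliced out the first row of pixels rather than the first channel. A standalone demonstration (illustrative, not part of the commit):

import numpy as np
import torch
from PIL import Image

# A 4x6 RGB image: numpy lays it out as (height, width, channels).
image = Image.new("RGB", (6, 4), color=(255, 0, 0))
np_image = np.array(image)
print(np_image.shape)           # (4, 6, 3)

print(np_image[0, :, :].shape)  # (6, 3)  -> the first ROW of pixels (wrong)
print(np_image[:, :, 0].shape)  # (4, 6)  -> the first CHANNEL (correct)

# With the fix, thresholding yields a (height, width) uint8 mask.
mask = (torch.from_numpy(np_image[:, :, 0]) >= 128).to(dtype=torch.uint8)
print(mask.shape)               # torch.Size([4, 6])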


@@ -340,16 +340,24 @@ class DenoiseLatentsInvocation(BaseInvocation):
             positive_conditioning_list = [positive_conditioning_list]
 
         text_embeddings: list[BasicConditioningInfo] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
         for positive_conditioning in positive_conditioning_list:
             positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
             text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
 
+            mask_name = positive_conditioning.mask_name
+            mask = None
+            if mask_name is not None:
+                mask = context.services.latents.get(mask_name)
+            text_embeddings_masks.append(mask)
 
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
 
         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
             text_embeddings=text_embeddings,
+            text_embedding_masks=text_embeddings_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
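
The loop keeps the two lists index-aligned: every positive prompt contributes exactly one entry to text_embeddings and one to text_embeddings_masks, with None meaning "this prompt applies to the whole image". A toy illustration of the invariant (names hypothetical, not from the commit):

embeddings = ["embeds_a", "embeds_b"]  # stand-ins for BasicConditioningInfo objects
masks = [None, "mask_b"]               # index-aligned; None = applies everywhere

# _prepare_text_embeddings later relies on this alignment via zip(..., strict=True),
# which raises ValueError if the lists ever drift apart in length.
for embeds, mask in zip(embeddings, masks, strict=True):
    print(embeds, "-> whole image" if mask is None else f"-> {mask}")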


@@ -63,6 +63,8 @@ class IPAdapterConditioningInfo:
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
     text_embeddings: list[BasicConditioningInfo]
+    text_embedding_masks: list[Optional[torch.Tensor]]
     guidance_scale: Union[float, List[float]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
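
With the new field, each embedding travels with an optional mask. A hedged sketch of building a ConditioningData for two regional prompts (the embedding objects are invented stand-ins, and fields with defaults are omitted):

import torch

# Left/right split of a 512x512 image; 1 marks where each prompt applies.
mask_left = torch.zeros((512, 512), dtype=torch.uint8)
mask_left[:, :256] = 1
mask_right = torch.zeros((512, 512), dtype=torch.uint8)
mask_right[:, 256:] = 1

conditioning_data = ConditioningData(
    unconditioned_embeddings=uc,                  # assumed BasicConditioningInfo
    text_embeddings=[embeds_left, embeds_right],  # assumed BasicConditioningInfo objects
    text_embedding_masks=[mask_left, mask_right],
    guidance_scale=7.5,
)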


@ -1,4 +1,5 @@
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
import torch
@ -8,6 +9,26 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.utils import USE_PEFT_BACKEND
@dataclass
class Range:
start: int
end: int
@dataclass
class RegionalPromptData:
# The region masks for each prompt.
# shape: (batch_size, num_prompts, height, width)
# dtype: float*
# The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
masks: torch.Tensor
# The embedding ranges for each prompt.
# The i'th mask is applied to the embeddings in:
# encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
embedding_ranges: list[Range]
class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
"""An attention processor that supports regional prompt attention for PyTorch 2.0."""
@@ -19,10 +40,8 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
         scale: float = 1.0,
-        regional_prompt_data=None,
+        regional_prompt_data: Optional[RegionalPromptData] = None,
     ) -> torch.FloatTensor:
-        assert regional_prompt_data is None
         residual = hidden_states
 
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
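
This hunk only changes the processor's signature: the old assert that regional_prompt_data is always None is dropped, but the data is not consumed yet. For orientation, one plausible way a processor could turn the masks and ranges into an attention mask (purely an illustrative sketch under the RegionalPromptData contract above, not code from this commit):

import torch
import torch.nn.functional as F

def build_regional_attn_mask(
    regional_prompt_data: RegionalPromptData, query_height: int, query_width: int, seq_len: int
) -> torch.Tensor:
    """Expand per-prompt spatial masks into a (batch, query_len, seq_len) boolean mask."""
    masks = regional_prompt_data.masks.float()  # (batch, num_prompts, H, W)
    batch = masks.shape[0]
    # Bring the region masks down to the attention map's spatial resolution.
    masks = F.interpolate(masks, size=(query_height, query_width), mode="nearest")
    query_len = query_height * query_width
    attn_mask = torch.zeros((batch, query_len, seq_len), dtype=torch.bool, device=masks.device)
    for i, r in enumerate(regional_prompt_data.embedding_ranges):
        region = masks[:, i].reshape(batch, query_len) > 0.5  # (batch, query_len)
        # Query positions inside region i may attend to prompt i's tokens.
        attn_mask[:, :, r.start : r.end] = region.unsqueeze(-1)
    return attn_mask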


@@ -2,9 +2,10 @@ from __future__ import annotations
 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
+import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
@@ -16,6 +17,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     PostprocessingSettings,
     SDXLConditioningInfo,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
 
 from .cross_attention_control import (
     CrossAttentionType,
@@ -308,26 +310,43 @@ class InvokeAIDiffuserComponent:
         return torch.cat([unconditioning, conditioning]), encoder_attention_mask
 
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        if mask is None:
+            # HACK(ryand): Figure out how to know the target device/dtype.
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
+        else:
+            # HACK(ryand): It would make more sense to do NEAREST resizing with an integer dtype, and probably on
+            # the CPU.
+            tf = torchvision.transforms.Resize(
+                (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+            )
+            mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
+            mask = tf(mask)
+            return mask
+
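
The NEAREST interpolation matters here: region masks are binary, and nearest-neighbor resizing keeps them binary instead of introducing blended edge values. A standalone illustration of the resize step (not part of the commit):

import torch
import torchvision

# A 4x4 binary mask scaled to 8x8 keeps hard 0/1 edges under NEAREST.
mask = torch.zeros((4, 4), dtype=torch.uint8)
mask[:, 2:] = 1

tf = torchvision.transforms.Resize((8, 8), interpolation=torchvision.transforms.InterpolationMode.NEAREST)
resized = tf(mask.unsqueeze(0).unsqueeze(0))  # (4, 4) -> (1, 1, 8, 8)
print(resized.shape)    # torch.Size([1, 1, 8, 8])
print(resized.unique()) # tensor([0, 1], dtype=torch.uint8) -> no blended values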
     def _prepare_text_embeddings(
-        self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
-    ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
-        if len(text_embeddings) == 1:
-            # If there is only one text embedding, we can just return it.
-            # We short-circuit here, because there are some features that are only supported when there is a single
-            # text_embedding provided.
-            return text_embeddings[0]
-
+        self,
+        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
+        masks: list[Optional[torch.Tensor]],
+        target_height: int,
+        target_width: int,
+    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
         is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
+
+        all_masks_are_none = all(mask is None for mask in masks)
+
         text_embedding = []
         pooled_embedding = None
         add_time_ids = None
+        processed_masks = []
+        cur_text_embedding_len = 0
+        embedding_ranges: list[Range] = []
 
-        for text_embedding_info in text_embeddings:
-            # TODO(ryand): Having to check this feels super hacky.
-            # Extra conditioning is not supported when there are multiple text embeddings.
+        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
+            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
             assert (
                 text_embedding_info.extra_conditioning is None
                 or not text_embedding_info.extra_conditioning.wants_cross_attention_control
@@ -343,21 +362,35 @@ class InvokeAIDiffuserComponent:
                 add_time_ids = text_embedding_info.add_time_ids
 
             text_embedding.append(text_embedding_info.embeds)
+            embedding_ranges.append(
+                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
+            )
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+            if not all_masks_are_none:
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
 
         text_embedding = torch.cat(text_embedding, dim=1)
         assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len
 
+        regional_prompt_data = None
+        if not all_masks_are_none:
+            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
+            processed_masks = torch.cat(processed_masks, dim=1)
+            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
+
         if is_sdxl:
             return SDXLConditioningInfo(
                 embeds=text_embedding,
                 extra_conditioning=None,
                 pooled_embeds=pooled_embedding,
                 add_time_ids=add_time_ids,
-            )
+            ), regional_prompt_data
 
         return BasicConditioningInfo(
             embeds=text_embedding,
             extra_conditioning=None,
-        )
+        ), regional_prompt_data
 
     def _apply_standard_conditioning(
         self,
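
The range bookkeeping in the loop above is easiest to see with numbers. A hypothetical trace for two prompts embedded to 77 and 154 tokens (lengths invented for illustration):

from dataclasses import dataclass

@dataclass
class Range:
    start: int
    end: int

seq_lens = [77, 154]  # hypothetical per-prompt embedding lengths
embedding_ranges, cur_text_embedding_len = [], 0
for n in seq_lens:
    embedding_ranges.append(Range(start=cur_text_embedding_len, end=cur_text_embedding_len + n))
    cur_text_embedding_len += n
print(embedding_ranges)  # [Range(start=0, end=77), Range(start=77, end=231)]
# After torch.cat(..., dim=1) the joint embedding has seq_len 231, and
# masks[:, i] pairs with text_embedding[:, ranges[i].start:ranges[i].end, :].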
@@ -374,11 +407,20 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
-        if len(conditioning_data.text_embeddings) > 1:
-            cross_attention_kwargs = {"regional_prompt_data": None}
+        # HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
+        # denoising step.
         cross_attention_kwargs = None
+        _, _, h, w = x.shape
+        text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
+            text_embeddings=conditioning_data.text_embeddings,
+            masks=conditioning_data.text_embedding_masks,
+            target_height=h,
+            target_width=w,
+        )
+        if regional_prompt_data is not None:
+            cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}
 
+        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
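
Downstream, diffusers forwards cross_attention_kwargs to each attention processor's __call__, which is how regional_prompt_data reaches RegionalPromptAttnProcessor2_0. A hedged sketch of that wiring (variable names such as both_conditionings are assumed from the surrounding method, not shown in this diff):

# Install the regional processor on every attention layer, then let diffusers
# route the kwargs through to it on each UNet forward pass.
unet.set_attn_processor(RegionalPromptAttnProcessor2_0())

noise_pred = unet(
    sample=x_twice,
    timestep=sigma_twice,
    encoder_hidden_states=both_conditionings,  # assumed: uncond + cond embeddings, batched
    cross_attention_kwargs={"regional_prompt_data": regional_prompt_data},
).sample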