Route masks into the RegionalPromptAttnProcessor2_0 processors.

Ryan Dick 2024-02-16 19:35:24 -05:00
parent 878bbc3527
commit 2d5d370f38
5 changed files with 94 additions and 23 deletions


@@ -33,7 +33,7 @@ class AddConditioningMaskInvocation(BaseInvocation):
     def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
         """Convert a PIL image to a uint8 mask tensor."""
         np_image = np.array(image)
-        torch_image = torch.from_numpy(np_image[0, :, :])
+        torch_image = torch.from_numpy(np_image[:, :, 0])
         mask = torch_image >= 128
         return mask.to(dtype=torch.uint8)
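
For context on the one-line fix above: `np.array(image)` returns an array of shape (height, width, channels) for an RGB PIL image, so `np_image[0, :, :]` selects the first row of pixels, while `np_image[:, :, 0]` selects the first channel, which is what the mask conversion needs. A minimal sketch of the corrected behaviour, with a made-up image size and rectangle purely for illustration:

import numpy as np
import torch
from PIL import Image

# Build a 64x64 black RGB image with a white 32x32 square in the top-left corner.
np_image = np.zeros((64, 64, 3), dtype=np.uint8)
np_image[:32, :32, :] = 255
image = Image.fromarray(np_image)

# Same logic as convert_image_to_mask() after the fix: take the first channel,
# then threshold at 128 to obtain a binary uint8 mask.
torch_image = torch.from_numpy(np.array(image)[:, :, 0])
mask = (torch_image >= 128).to(dtype=torch.uint8)
assert mask.shape == (64, 64)
assert bool(mask[:32, :32].all()) and not bool(mask[32:, 32:].any())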


@@ -340,16 +340,24 @@ class DenoiseLatentsInvocation(BaseInvocation):
             positive_conditioning_list = [positive_conditioning_list]

         text_embeddings: list[BasicConditioningInfo] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
         for positive_conditioning in positive_conditioning_list:
             positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
             text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
+            mask_name = positive_conditioning.mask_name
+            mask = None
+            if mask_name is not None:
+                mask = context.services.latents.get(mask_name)
+            text_embeddings_masks.append(mask)

         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
             text_embeddings=text_embeddings,
+            text_embedding_masks=text_embeddings_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
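
The two lists built above are parallel: `text_embeddings_masks[i]` holds the optional region mask for `text_embeddings[i]`, and `None` means the prompt applies to the whole image. A hedged sketch of that pairing, using a stand-in dataclass instead of the real `BasicConditioningInfo` and made-up tensor shapes:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class FakeConditioningInfo:
    """Stand-in for BasicConditioningInfo, purely for illustration."""
    embeds: torch.Tensor

# Two prompts: the first covers the whole image (mask is None), the second is
# restricted to the top half by a uint8 mask at image resolution.
text_embeddings = [
    FakeConditioningInfo(embeds=torch.randn(1, 77, 768)),
    FakeConditioningInfo(embeds=torch.randn(1, 77, 768)),
]
mask_b = torch.zeros((512, 512), dtype=torch.uint8)
mask_b[:256, :] = 1
text_embeddings_masks: list[Optional[torch.Tensor]] = [None, mask_b]

# The lists are consumed together: mask i belongs to prompt i.
assert len(text_embeddings) == len(text_embeddings_masks)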


@@ -63,6 +63,8 @@ class IPAdapterConditioningInfo:
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
     text_embeddings: list[BasicConditioningInfo]
+    text_embedding_masks: list[Optional[torch.Tensor]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).


@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import Optional

 import torch
@@ -8,6 +9,26 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import USE_PEFT_BACKEND


+@dataclass
+class Range:
+    start: int
+    end: int
+
+
+@dataclass
+class RegionalPromptData:
+    # The region masks for each prompt.
+    # shape: (batch_size, num_prompts, height, width)
+    # dtype: float*
+    # The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
+    masks: torch.Tensor
+
+    # The embedding ranges for each prompt.
+    # The i'th mask is applied to the embeddings in:
+    # encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
+    embedding_ranges: list[Range]
+
+
 class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
     """An attention processor that supports regional prompt attention for PyTorch 2.0."""
@@ -19,10 +40,8 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
         scale: float = 1.0,
-        regional_prompt_data=None,
+        regional_prompt_data: Optional[RegionalPromptData] = None,
     ) -> torch.FloatTensor:
-        assert regional_prompt_data is None
-
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
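
To make the `RegionalPromptData` contract concrete, here is a small sketch with illustrative sizes: two prompts whose embeddings were concatenated along the sequence dimension, one mask per prompt, and the `embedding_ranges` that recover each prompt's slice of `encoder_hidden_states`:

import torch

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData

# Assume two prompt embeddings of 77 tokens each were concatenated along dim=1.
encoder_hidden_states = torch.randn(1, 154, 768)

# (batch_size, num_prompts, height, width); prompt 0 owns the left half, prompt 1 the right half.
masks = torch.zeros((1, 2, 64, 64), dtype=torch.float16)
masks[:, 0, :, :32] = 1.0
masks[:, 1, :, 32:] = 1.0

regional_prompt_data = RegionalPromptData(
    masks=masks,
    embedding_ranges=[Range(start=0, end=77), Range(start=77, end=154)],
)

# The i'th mask governs the token slice belonging to the i'th prompt.
for i, r in enumerate(regional_prompt_data.embedding_ranges):
    prompt_tokens = encoder_hidden_states[:, r.start : r.end, :]
    prompt_mask = regional_prompt_data.masks[:, i]
    assert prompt_tokens.shape == (1, 77, 768)
    assert prompt_mask.shape == (1, 64, 64)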


@@ -2,9 +2,10 @@ from __future__ import annotations

 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
+import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
@@ -16,6 +17,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     PostprocessingSettings,
     SDXLConditioningInfo,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData

 from .cross_attention_control import (
     CrossAttentionType,
@@ -308,26 +310,43 @@ class InvokeAIDiffuserComponent:
         return torch.cat([unconditioning, conditioning]), encoder_attention_mask

-    # methods below are called from do_diffusion_step and should be considered private to this class.
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        if mask is None:
+            # HACK(ryand): Figure out how to know the target device/dtype.
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
+        else:
+            # HACK(ryand): It would make more sense to do NEAREST resizing with an integer dtype, and probably on the
+            # CPU.
+            tf = torchvision.transforms.Resize(
+                (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+            )
+            mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
+            mask = tf(mask)
+            return mask

     def _prepare_text_embeddings(
-        self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
-    ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
-        if len(text_embeddings) == 1:
-            # If there is only one text embedding, we can just return it.
-            # We short-circuit here, because there are some features that are only supported when there is a single
-            # text_embedding provided.
-            return text_embeddings[0]
-
+        self,
+        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
+        masks: list[Optional[torch.Tensor]],
+        target_height: int,
+        target_width: int,
+    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
         is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
+        all_masks_are_none = all(mask is None for mask in masks)

         text_embedding = []
         pooled_embedding = None
         add_time_ids = None
+        processed_masks = []
+        cur_text_embedding_len = 0
+        embedding_ranges: list[Range] = []

-        for text_embedding_info in text_embeddings:
-            # TODO(ryand): Having to check this feels super hacky.
-            # Extra conditioning is not supported when there are multiple text embeddings.
+        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
+            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
             assert (
                 text_embedding_info.extra_conditioning is None
                 or not text_embedding_info.extra_conditioning.wants_cross_attention_control
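
The resizing branch of `_preprocess_regional_prompt_mask` above brings an image-space `(h, w)` mask down to the target spatial size with nearest-neighbour interpolation. A standalone sketch of the same transform, with illustrative sizes (a 512x512 mask resized to a 64x64 grid) and float32 used to keep the example device-agnostic:

import torch
import torchvision

# Illustrative mask: the left half of the image is selected.
mask = torch.zeros((512, 512), dtype=torch.float32)
mask[:, :256] = 1.0

tf = torchvision.transforms.Resize(
    (64, 64), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
resized = tf(mask.unsqueeze(0).unsqueeze(0))  # (h, w) -> (1, 1, h, w)
assert resized.shape == (1, 1, 64, 64)
assert bool((resized[..., :32] == 1.0).all()) and bool((resized[..., 32:] == 0.0).all())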
@@ -343,21 +362,35 @@ class InvokeAIDiffuserComponent:
                 add_time_ids = text_embedding_info.add_time_ids

             text_embedding.append(text_embedding_info.embeds)
+            embedding_ranges.append(
+                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
+            )
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+            if not all_masks_are_none:
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))

         text_embedding = torch.cat(text_embedding, dim=1)
         assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

+        regional_prompt_data = None
+        if not all_masks_are_none:
+            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
+            processed_masks = torch.cat(processed_masks, dim=1)
+            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
+
         if is_sdxl:
             return SDXLConditioningInfo(
                 embeds=text_embedding,
                 extra_conditioning=None,
                 pooled_embeds=pooled_embedding,
                 add_time_ids=add_time_ids,
-            )
+            ), regional_prompt_data

         return BasicConditioningInfo(
             embeds=text_embedding,
             extra_conditioning=None,
-        )
+        ), regional_prompt_data

     def _apply_standard_conditioning(
         self,
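
The range bookkeeping in the hunk above can be checked in isolation: each prompt's embeddings occupy a contiguous token slice of the concatenated embedding, and `embedding_ranges` records where each slice starts and ends. A small worked example with made-up embedding lengths (77 and 154 tokens):

import torch

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range

# Illustrative per-prompt embeddings: prompt 0 has 77 tokens, prompt 1 has 154.
embeds = [torch.randn(1, 77, 768), torch.randn(1, 154, 768)]

cur_text_embedding_len = 0
embedding_ranges: list[Range] = []
for e in embeds:
    embedding_ranges.append(Range(start=cur_text_embedding_len, end=cur_text_embedding_len + e.shape[1]))
    cur_text_embedding_len += e.shape[1]

text_embedding = torch.cat(embeds, dim=1)
assert text_embedding.shape == (1, 231, 768)
assert embedding_ranges == [Range(start=0, end=77), Range(start=77, end=231)]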
@@ -374,11 +407,20 @@
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
-        if len(conditioning_data.text_embeddings) > 1:
-            cross_attention_kwargs = {"regional_prompt_data": None}
-
+        # HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
+        # denoising step.
         cross_attention_kwargs = None
+        _, _, h, w = x.shape
+        text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
+            text_embeddings=conditioning_data.text_embeddings,
+            masks=conditioning_data.text_embedding_masks,
+            target_height=h,
+            target_width=w,
+        )
+        if regional_prompt_data is not None:
+            cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}

+        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {