Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 2d5d370f38 (parent: 878bbc3527)

Route masks into the RegionalPromptAttnProcessor2_0 processors.
@@ -33,7 +33,7 @@ class AddConditioningMaskInvocation(BaseInvocation):
 def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
     """Convert a PIL image to a uint8 mask tensor."""
     np_image = np.array(image)
-    torch_image = torch.from_numpy(np_image[0, :, :])
+    torch_image = torch.from_numpy(np_image[:, :, 0])
     mask = torch_image >= 128
     return mask.to(dtype=torch.uint8)
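The one-line fix above corrects the channel indexing: np.array() on a PIL image returns an array in (height, width, channels) order, so the first channel is np_image[:, :, 0]; the old np_image[0, :, :] sliced the first pixel row instead. A quick self-contained check (sizes are illustrative):

    import numpy as np
    import torch
    from PIL import Image

    # np.array() on a PIL image yields (height, width, channels).
    image = Image.new(mode="RGB", size=(128, 64))  # PIL size is (width, height)
    np_image = np.array(image)
    assert np_image.shape == (64, 128, 3)

    # Old (buggy) indexing returned the first pixel row, shape (128, 3).
    assert np_image[0, :, :].shape == (128, 3)

    # Fixed indexing selects channel 0 for every pixel, shape (64, 128).
    torch_image = torch.from_numpy(np_image[:, :, 0])
    assert torch_image.shape == (64, 128)

    # Thresholding then yields a proper (height, width) uint8 mask.
    mask = (torch_image >= 128).to(dtype=torch.uint8)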
@@ -340,16 +340,24 @@ class DenoiseLatentsInvocation(BaseInvocation):
             positive_conditioning_list = [positive_conditioning_list]

         text_embeddings: list[BasicConditioningInfo] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
         for positive_conditioning in positive_conditioning_list:
             positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
             text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))

+            mask_name = positive_conditioning.mask_name
+            mask = None
+            if mask_name is not None:
+                mask = context.services.latents.get(mask_name)
+            text_embeddings_masks.append(mask)
+
         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
             text_embeddings=text_embeddings,
+            text_embedding_masks=text_embeddings_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
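The hunk above keeps text_embeddings and text_embeddings_masks index-aligned: mask i restricts prompt i, and a None entry means the prompt applies to the whole image. A minimal sketch of that invariant (the tensors and sizes below are illustrative stand-ins, not InvokeAI's types):

    from typing import Optional

    import torch

    text_embeddings: list[torch.Tensor] = []
    text_embeddings_masks: list[Optional[torch.Tensor]] = []

    # Two prompts: the first is restricted to a region, the second is global.
    prompts = [
        (torch.randn(1, 77, 768), torch.ones(64, 64, dtype=torch.uint8)),
        (torch.randn(1, 77, 768), None),
    ]
    for embeds, mask in prompts:
        text_embeddings.append(embeds)
        text_embeddings_masks.append(mask)  # appended even when None

    # The lists stay parallel, so downstream code can zip() them safely.
    assert len(text_embeddings) == len(text_embeddings_masks)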
@@ -63,6 +63,8 @@ class IPAdapterConditioningInfo:
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
     text_embeddings: list[BasicConditioningInfo]
+    text_embedding_masks: list[Optional[torch.Tensor]]
+
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+from dataclasses import dataclass
 from typing import Optional

 import torch
@@ -8,6 +9,26 @@ from diffusers.models.attention_processor import Attention, AttnProcessor2_0
 from diffusers.utils import USE_PEFT_BACKEND


+@dataclass
+class Range:
+    start: int
+    end: int
+
+
+@dataclass
+class RegionalPromptData:
+    # The region masks for each prompt.
+    # shape: (batch_size, num_prompts, height, width)
+    # dtype: float*
+    # The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
+    masks: torch.Tensor
+
+    # The embedding ranges for each prompt.
+    # The i'th mask is applied to the embeddings in:
+    # encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
+    embedding_ranges: list[Range]
+
+
 class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
     """An attention processor that supports regional prompt attention for PyTorch 2.0."""

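To make the two fields concrete: with two 77-token prompts concatenated along the sequence dimension, embedding_ranges records which slice of encoder_hidden_states each mask governs. A small usage sketch assuming the Range and RegionalPromptData dataclasses above (all sizes are illustrative):

    import torch

    # batch_size=1, num_prompts=2, 64x64 latent-space masks.
    masks = torch.stack([torch.ones(64, 64), torch.zeros(64, 64)]).unsqueeze(0)
    ranges = [Range(start=0, end=77), Range(start=77, end=154)]
    data = RegionalPromptData(masks=masks, embedding_ranges=ranges)

    # Mask i governs the token slice encoder_hidden_states[:, start:end, :].
    encoder_hidden_states = torch.randn(1, 154, 768)
    for i, r in enumerate(data.embedding_ranges):
        prompt_tokens = encoder_hidden_states[:, r.start : r.end, :]
        assert prompt_tokens.shape[1] == r.end - r.start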
@@ -19,10 +40,8 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
         scale: float = 1.0,
-        regional_prompt_data=None,
+        regional_prompt_data: Optional[RegionalPromptData] = None,
     ) -> torch.FloatTensor:
-        assert regional_prompt_data is None
-
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
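Note that this hunk only types the regional_prompt_data parameter and drops the assert that previously rejected it; how the processor consumes the masks is not shown in this diff. For orientation only, here is one plausible way (not this commit's implementation) to turn RegionalPromptData into an additive cross-attention bias, where query positions outside a prompt's region are blocked from attending to that prompt's tokens:

    import torch

    def build_regional_attn_bias(
        regional_prompt_data,  # a RegionalPromptData instance
        query_seq_len: int,    # number of spatial query positions at this layer
        key_seq_len: int,      # total token length across all concatenated prompts
    ) -> torch.Tensor:
        batch_size, num_prompts, _, _ = regional_prompt_data.masks.shape
        # Downsample each mask to the attention resolution, then flatten the
        # spatial dims into one query axis. Assumes square latents (illustrative).
        spatial = int(query_seq_len**0.5)
        masks = torch.nn.functional.interpolate(
            regional_prompt_data.masks, size=(spatial, spatial), mode="nearest"
        ).flatten(2)  # (batch_size, num_prompts, query_seq_len)

        bias = torch.zeros(batch_size, query_seq_len, key_seq_len)
        for i, r in enumerate(regional_prompt_data.embedding_ranges):
            outside_region = masks[:, i, :] < 0.5  # (batch_size, query_seq_len)
            bias[:, :, r.start : r.end].masked_fill_(outside_region.unsqueeze(-1), float("-inf"))
        return bias  # addable to the attention scores before softmax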
@@ -2,9 +2,10 @@ from __future__ import annotations

 import math
 from contextlib import contextmanager
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
+import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias

@@ -16,6 +17,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     PostprocessingSettings,
     SDXLConditioningInfo,
 )
+from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData

 from .cross_attention_control import (
     CrossAttentionType,
@@ -308,26 +310,43 @@ class InvokeAIDiffuserComponent:

         return torch.cat([unconditioning, conditioning]), encoder_attention_mask

-    # methods below are called from do_diffusion_step and should be considered private to this class.
+    def _preprocess_regional_prompt_mask(
+        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
+    ) -> torch.Tensor:
+        if mask is None:
+            # HACK(ryand): Figure out how to know the target device/dtype.
+            return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
+        else:
+            # HACK(ryand): It would make more sense to do NEAREST resizing with an integer dtype, and probably on the
+            # CPU.
+            tf = torchvision.transforms.Resize(
+                (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+            )
+            mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
+            mask = tf(mask)
+
+        return mask
+
     def _prepare_text_embeddings(
-        self, text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]]
-    ) -> Union[BasicConditioningInfo, SDXLConditioningInfo]:
-        if len(text_embeddings) == 1:
-            # If there is only one text embedding, we can just return it.
-            # We short-circuit here, because there are some features that are only supported when there is a single
-            # text_embedding provided.
-            return text_embeddings[0]
-
+        self,
+        text_embeddings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
+        masks: list[Optional[torch.Tensor]],
+        target_height: int,
+        target_width: int,
+    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
         is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo

+        all_masks_are_none = all(mask is None for mask in masks)
+
         text_embedding = []
         pooled_embedding = None
         add_time_ids = None
+        processed_masks = []
+        cur_text_embedding_len = 0
+        embedding_ranges: list[Range] = []

-        for text_embedding_info in text_embeddings:
-            # TODO(ryand): Having to check this feels super hacky.
-            # Extra conditioning is not supported when there are multiple text embeddings.
+        for text_embedding_info, mask in zip(text_embeddings, masks, strict=True):
+            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
             assert (
                 text_embedding_info.extra_conditioning is None
                 or not text_embedding_info.extra_conditioning.wants_cross_attention_control
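_preprocess_regional_prompt_mask normalizes every mask to the model's working resolution: a missing mask becomes an all-ones (select everything) tensor, and an existing (h, w) mask is resized with NEAREST interpolation, which keeps the values strictly binary instead of blending them. A standalone sketch of the resize step (sizes are illustrative):

    import torch
    import torchvision

    # A hypothetical 512x512 image-space mask, resized to a 64x64 latent grid.
    mask = (torch.rand(512, 512) > 0.5).to(torch.float32)
    tf = torchvision.transforms.Resize(
        (64, 64), interpolation=torchvision.transforms.InterpolationMode.NEAREST
    )
    resized = tf(mask.unsqueeze(0).unsqueeze(0))  # (h, w) -> (1, 1, h, w)

    assert resized.shape == (1, 1, 64, 64)
    # NEAREST sampling never produces in-between values, so the mask stays binary.
    assert set(resized.unique().tolist()) <= {0.0, 1.0}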
@@ -343,21 +362,35 @@ class InvokeAIDiffuserComponent:
                 add_time_ids = text_embedding_info.add_time_ids

             text_embedding.append(text_embedding_info.embeds)
+            embedding_ranges.append(
+                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
+            )
+            cur_text_embedding_len += text_embedding_info.embeds.shape[1]
+
+            if not all_masks_are_none:
+                processed_masks.append(self._preprocess_regional_prompt_mask(mask, target_height, target_width))
+
         text_embedding = torch.cat(text_embedding, dim=1)
         assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

+        regional_prompt_data = None
+        if not all_masks_are_none:
+            # TODO(ryand): Think about at what point a batch dimension should be added to the masks.
+            processed_masks = torch.cat(processed_masks, dim=1)
+
+            regional_prompt_data = RegionalPromptData(masks=processed_masks, embedding_ranges=embedding_ranges)
+
         if is_sdxl:
             return SDXLConditioningInfo(
                 embeds=text_embedding,
                 extra_conditioning=None,
                 pooled_embeds=pooled_embedding,
                 add_time_ids=add_time_ids,
-            )
+            ), regional_prompt_data
         return BasicConditioningInfo(
             embeds=text_embedding,
             extra_conditioning=None,
-        )
+        ), regional_prompt_data

     def _apply_standard_conditioning(
         self,
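The bookkeeping above is the heart of the change: as the per-prompt embeddings are concatenated along the sequence dimension, cur_text_embedding_len tracks where each prompt's tokens start and end, so two prompts of 77 and 154 tokens yield the ranges [0, 77) and [77, 231). A reduced version of the loop (tensor sizes are illustrative):

    import torch

    embeds_list = [torch.randn(1, 77, 768), torch.randn(1, 154, 768)]

    cur_len = 0
    ranges: list[tuple[int, int]] = []
    for embeds in embeds_list:
        ranges.append((cur_len, cur_len + embeds.shape[1]))
        cur_len += embeds.shape[1]

    text_embedding = torch.cat(embeds_list, dim=1)
    assert text_embedding.shape == (1, 231, 768)
    assert ranges == [(0, 77), (77, 231)]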
@@ -374,11 +407,20 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

-        text_embeddings = self._prepare_text_embeddings(conditioning_data.text_embeddings)
-        if len(conditioning_data.text_embeddings) > 1:
-            cross_attention_kwargs = {"regional_prompt_data": None}
-
+        # HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
+        # denoising step.
         cross_attention_kwargs = None
+        _, _, h, w = x.shape
+        text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
+            text_embeddings=conditioning_data.text_embeddings,
+            masks=conditioning_data.text_embedding_masks,
+            target_height=h,
+            target_width=w,
+        )
+        if regional_prompt_data is not None:
+            cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}
+
+        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
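Once cross_attention_kwargs carries regional_prompt_data, diffusers forwards those kwargs into every attention processor's __call__, which is what finally routes the masks into RegionalPromptAttnProcessor2_0. A minimal runnable sketch of that mechanism using a toy UNet2DConditionModel (the sizes and the forwarded "scale" kwarg are illustrative; InvokeAI's actual UNet invocation differs):

    import torch
    from diffusers import UNet2DConditionModel

    # Tiny toy UNet, just to show kwargs flowing into attention processors.
    unet = UNet2DConditionModel(
        sample_size=8,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        layers_per_block=1,
    )
    # unet.set_attn_processor(RegionalPromptAttnProcessor2_0()) would install the
    # custom processor on every attention layer, making it the recipient of the
    # forwarded kwargs.

    sample = torch.randn(1, 4, 8, 8)
    encoder_hidden_states = torch.randn(1, 77, 32)
    out = unet(
        sample,
        timestep=0,
        encoder_hidden_states=encoder_hidden_states,
        # Every entry here is forwarded to each attention processor's __call__.
        cross_attention_kwargs={"scale": 1.0},
    ).sample
    assert out.shape == sample.shape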