Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Handle conditioned and unconditioned text conditioning in the same way for regional prompt attention.
parent b0fcbe552e
commit 2966c8de2c
@@ -65,6 +65,7 @@ class IPAdapterConditioningInfo:

@dataclass
class ConditioningData:
    # TODO(ryand): Support masks for unconditioned_embeddings.
    unconditioned_embeddings: BasicConditioningInfo
    text_embeddings: list[BasicConditioningInfo]
    text_embedding_masks: list[Optional[torch.Tensor]]
@@ -15,78 +15,102 @@ class Range:
    end: int


@dataclass
class RegionalPromptData:
    # The region masks for each prompt.
    # shape: (batch_size, num_prompts, height, width)
    # dtype: float*
    # The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
    masks: torch.Tensor
    def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
        self._attn_masks_by_seq_len = attn_masks_by_seq_len

    # The embedding ranges for each prompt.
    # The i'th mask is applied to the embeddings in:
    # encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
    embedding_ranges: list[Range]
    @classmethod
    def from_masks_and_ranges(
        cls,
        masks: list[torch.Tensor],
        embedding_ranges: list[list[Range]],
        key_seq_len: int,
        # TODO(ryand): Pass in a list of downscale factors?
        max_downscale_factor: int = 8,
    ):
        """Construct a `RegionalPromptData` object.

        Args:
            masks (list[torch.Tensor]): masks[i] contains the region masks for the i'th sample in the batch.
                The shape of masks[i] is (num_prompts, height, width), and dtype=bool. The mask is set to True in
                regions where the prompt should be applied, and False elsewhere.

            embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
                prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
                encoder_hidden_states[i, embedding_ranges[i][j].start:embedding_ranges[i][j].end, :].

            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
                cross-attention layers).
        """
        attn_masks_by_seq_len = {}

        # batch_attn_masks_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query
        # sequence length of s.
        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
        for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
            batch_attn_masks_by_seq_len.append({})

            # Convert the bool masks to float masks so that max pooling can be applied.
            batch_masks = batch_masks.to(dtype=torch.float32)

            # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
            downscale_factor = 1
            while downscale_factor <= max_downscale_factor:
                _, num_prompts, h, w = batch_masks.shape
                query_seq_len = h * w

                # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
                batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))

                # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
                # `encoder_hidden_states`.
                # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
                # TODO(ryand): What device / dtype should this be?
                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))

                for prompt_idx, embedding_range in enumerate(batch_ranges):
                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
                        :, prompt_idx, :, :
                    ]

                batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask

                downscale_factor *= 2
                if downscale_factor <= max_downscale_factor:
                    # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
                    # regions to be lost entirely.
                    # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
                    # potentially use a weighted mask rather than a binary mask.
                    batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)

        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
        for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
            attn_masks_by_seq_len[query_seq_len] = torch.cat(
                [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
            )

        return cls(attn_masks_by_seq_len)

    def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
        """Get the attention mask for the given query sequence length (i.e. downscaling level).

        This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
        it changes at each downscaling level in the model.

        key_seq_len is the length of the expected prompt embeddings.

        Returns:
            torch.Tensor: The masks.
                shape: (batch_size, query_seq_len, key_seq_len).
                dtype: float
                The mask is a binary mask with values of 0.0 and 1.0.
        """
        return self._attn_masks_by_seq_len[query_seq_len]


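Purely for orientation (not part of this commit, and the import path below is hypothetical), a sketch of how the precomputed per-resolution attention masks could be built and queried:

import torch

# Hypothetical import path; Range and RegionalPromptData are the classes shown in the diff above.
from regional_prompt_attention import Range, RegionalPromptData

# Two prompts over a 64x64 latent grid: prompt 0 covers the left half, prompt 1 the right half.
masks = torch.zeros((1, 2, 64, 64), dtype=torch.bool)
masks[0, 0, :, :32] = True
masks[0, 1, :, 32:] = True

# Prompt 0 occupies tokens [0, 77) and prompt 1 occupies tokens [77, 154) of the concatenated embedding.
ranges = [Range(start=0, end=77), Range(start=77, end=154)]

# One attention mask is precomputed per query length: 64*64, 32*32, 16*16, 8*8
# (downscale factors 1, 2, 4, 8).
prompt_data = RegionalPromptData.from_masks_and_ranges(
    masks=[masks], embedding_ranges=[ranges], key_seq_len=154
)
attn_mask = prompt_data.get_attn_mask(query_seq_len=32 * 32)
print(attn_mask.shape)  # torch.Size([1, 1024, 154])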
class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
    """An attention processor that supports regional prompt attention for PyTorch 2.0."""

    def _prepare_regional_prompt_attention_mask(
        self,
        regional_prompt_data: RegionalPromptData,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        orig_attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Infer the current spatial dimensions from the shape of `hidden_states`.
        _, query_seq_len, _ = hidden_states.shape
        per_prompt_query_masks = regional_prompt_data.masks
        _, _, h, w = per_prompt_query_masks.shape

        # Downsample by factors of 2 until the spatial dimensions match the current query sequence length.
        scale_factor = 1
        while h * w > query_seq_len:
            scale_factor *= 2
            h //= 2
            w //= 2
        assert h * w == query_seq_len

        # Convert the bool masks to float masks.
        per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)

        # Apply max-pooling to resize the masks to the target spatial dimensions.
        # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
        # here.
        per_prompt_query_masks = F.max_pool2d(per_prompt_query_masks, kernel_size=scale_factor, stride=scale_factor)
        batch_size, num_prompts, resized_h, resized_w = per_prompt_query_masks.shape
        assert resized_h == h and resized_w == w

        # Flatten the spatial dimensions of the masks.
        # Shape after reshape: (batch_size, num_prompts, query_seq_len, 1)
        per_prompt_query_masks = per_prompt_query_masks.reshape((batch_size, num_prompts, -1, 1))

        # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
        # `encoder_hidden_states`.

        # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
        _, key_seq_len, _ = encoder_hidden_states.shape
        # HACK(ryand): We are assuming the batch size.
        attn_mask = torch.zeros((2, query_seq_len, key_seq_len), device=hidden_states.device)

        for i, embedding_range in enumerate(regional_prompt_data.embedding_ranges):
            # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. This is too fragile
            # to merge.
            attn_mask[1, :, embedding_range.start : embedding_range.end] = per_prompt_query_masks[:, i, :, :]

        # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. We are also assuming
        # the intent of attn_mask. And we shouldn't have to do this awkward mask type conversion.
        orig_mask = torch.zeros_like(orig_attn_mask[0, ...])
        orig_mask[orig_attn_mask[0, ...] > -0.5] = 1.0
        attn_mask[0, ...] = orig_mask

        return attn_mask > 0.5

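Not part of the diff, but as a quick sanity check of the loop above: the downscale level can be recovered purely from the flattened query length, since each UNet attention level quarters the number of latent positions. A standalone sketch:

h, w = 96, 64  # example latent size; 96 * 64 = 6144 query positions at full resolution
for query_seq_len in (6144, 1536, 384, 96):
    scale_factor, hh, ww = 1, h, w
    while hh * ww > query_seq_len:
        scale_factor *= 2
        hh //= 2
        ww //= 2
    assert hh * ww == query_seq_len
    print(query_seq_len, scale_factor, (hh, ww))
# Prints: 6144 1 (96, 64) / 1536 2 (48, 32) / 384 4 (24, 16) / 96 8 (12, 8)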
    def __call__(
        self,
        attn: Attention,
@@ -114,9 +138,16 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
        if encoder_hidden_states is not None:
            assert regional_prompt_data is not None
            assert attention_mask is not None
            attention_mask = self._prepare_regional_prompt_attention_mask(
                regional_prompt_data, hidden_states, encoder_hidden_states, attention_mask
            _, query_seq_len, _ = hidden_states.shape
            prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
            # TODO(ryand): Avoid redundant type/device conversion here.
            prompt_region_attention_mask = prompt_region_attention_mask.to(
                dtype=attention_mask.dtype, device=attention_mask.device
            )
            prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
            prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0

            attention_mask = prompt_region_attention_mask + attention_mask

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

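An editorial aside (not from the commit): the two in-place assignments above convert the binary region mask into an additive bias, so that it can simply be summed with whatever additive attention mask diffusers already built (e.g. for padding). A minimal standalone illustration of that conversion:

import torch

query_seq_len, key_seq_len = 4, 6

# Binary region mask: 1.0 where a query position may attend to a key position, 0.0 elsewhere.
region_mask = torch.tensor([1.0, 1.0, 0.0, 0.0, 0.0, 0.0]).repeat(query_seq_len, 1).unsqueeze(0)

# Pre-existing additive mask (e.g. key padding): 0.0 for valid keys, -10000.0 for masked keys.
padding_mask = torch.zeros((1, query_seq_len, key_seq_len))
padding_mask[..., -1] = -10000.0

# Same conversion as in the processor: 0/1 -> -10000.0/0.0.
region_bias = region_mask.clone()
region_bias[region_bias < 0.5] = -10000.0
region_bias[region_bias >= 0.5] = 0.0

combined = region_bias + padding_mask  # only positions allowed by both masks stay at 0.0
print(combined[0, 0])  # tensor([0., 0., -10000., -10000., -10000., -20000.])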
@@ -36,6 +36,142 @@ ModelForwardCallback: TypeAlias = Union[
]


class RegionalTextConditioningInfo:
    def __init__(
        self,
        text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
        masks: Optional[torch.Tensor] = None,
        embedding_ranges: Optional[list[Range]] = None,
    ):
        """Initialize a RegionalTextConditioningInfo.

        Args:
            text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
                after concatenating the embeddings for all regions.
            masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
            embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
        """
        self.text_conditioning = text_conditioning
        self.masks = masks
        self.embedding_ranges = embedding_ranges

        assert (self.masks is None) == (self.embedding_ranges is None)
        if self.masks is not None:
            assert self.masks.shape[1] == len(self.embedding_ranges)

    def has_region_masks(self):
        if self.masks is None:
            return False
        return any(mask is not None for mask in self.masks)

    def is_sdxl(self):
        return isinstance(self.text_conditioning, SDXLConditioningInfo)

    @classmethod
    def _preprocess_regional_prompt_mask(
        cls, mask: Optional[torch.Tensor], target_height: int, target_width: int
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target height and width.

        If mask is None, returns a mask of all ones with the target height and width.
        If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.

        Returns:
            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
        """
        if mask is None:
            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)

        tf = torchvision.transforms.Resize(
            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
        mask = tf(mask)

        return mask

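As an illustration only (this snippet is not from the repository, and it substitutes torch.nn.functional.interpolate for the torchvision Resize used above), the effect of nearest-neighbor resizing a binary region mask down to the latent resolution:

import torch
import torch.nn.functional as F

# A 512x512 image-space mask covering the top-left quadrant.
mask = torch.zeros((1, 1, 512, 512), dtype=torch.bool)
mask[..., :256, :256] = True

# Resize to a 64x64 latent grid. Interpolation runs on a float copy and is thresholded back to bool,
# which is equivalent to nearest-neighbor for a binary mask.
latent_mask = F.interpolate(mask.float(), size=(64, 64), mode="nearest") > 0.5
print(latent_mask.shape)                        # torch.Size([1, 1, 64, 64])
print(latent_mask[..., :32, :32].all().item())  # True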
    @classmethod
    def from_text_conditioning_and_masks(
        cls,
        text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
        masks: Optional[list[Optional[torch.Tensor]]],
        latent_height: int,
        latent_width: int,
    ):
        if masks is None:
            masks = [None] * len(text_conditionings)
        assert len(text_conditionings) == len(masks)

        is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo

        all_masks_are_none = all(mask is None for mask in masks)

        text_embedding = []
        pooled_embedding = None
        add_time_ids = None
        processed_masks = []
        cur_text_embedding_len = 0
        embedding_ranges: list[Range] = []

        for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
            # HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
            assert (
                text_embedding_info.extra_conditioning is None
                or not text_embedding_info.extra_conditioning.wants_cross_attention_control
            )

            if is_sdxl:
                # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
                # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
                # the conditioning info, then we shouldn't allow it to be passed in.
                # How does Compel handle this? Options that come to mind:
                # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
                # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
                #   this is likely the global prompt.
                if pooled_embedding is None:
                    pooled_embedding = text_embedding_info.pooled_embeds
                if add_time_ids is None:
                    add_time_ids = text_embedding_info.add_time_ids

            text_embedding.append(text_embedding_info.embeds)
            embedding_ranges.append(
                Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
            )
            cur_text_embedding_len += text_embedding_info.embeds.shape[1]

            if not all_masks_are_none:
                processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))

        text_embedding = torch.cat(text_embedding, dim=1)
        assert len(text_embedding.shape) == 3  # batch_size, seq_len, token_len

        if not all_masks_are_none:
            processed_masks = torch.cat(processed_masks, dim=1)
        else:
            processed_masks = None
            embedding_ranges = None

        if is_sdxl:
            return cls(
                text_conditioning=SDXLConditioningInfo(
                    embeds=text_embedding,
                    extra_conditioning=None,
                    pooled_embeds=pooled_embedding,
                    add_time_ids=add_time_ids,
                ),
                masks=processed_masks,
                embedding_ranges=embedding_ranges,
            )
        return cls(
            text_conditioning=BasicConditioningInfo(
                embeds=text_embedding,
                extra_conditioning=None,
            ),
            masks=processed_masks,
            embedding_ranges=embedding_ranges,
        )


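Before moving on, a small self-contained sketch (tensor sizes are invented for illustration) of the bookkeeping performed above: concatenating per-region prompt embeddings along the sequence dimension while recording each region's token range:

import torch

# Three regional prompts with different token counts; 768 is a stand-in for the embedding width.
region_embeds = [torch.randn(1, 77, 768), torch.randn(1, 77, 768), torch.randn(1, 154, 768)]

ranges = []  # analogous to the list of Range(start=..., end=...) objects built above
offset = 0
for embeds in region_embeds:
    seq_len = embeds.shape[1]
    ranges.append((offset, offset + seq_len))
    offset += seq_len

concatenated = torch.cat(region_embeds, dim=1)
print(concatenated.shape)  # torch.Size([1, 308, 768])
print(ranges)              # [(0, 77), (77, 154), (154, 308)]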
class InvokeAIDiffuserComponent:
    """
    The aim of this component is to provide a single place for code that can be applied identically to
@@ -59,7 +195,6 @@ class InvokeAIDiffuserComponent:
        :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
        """
        config = InvokeAIAppConfig.get_config()
        self.conditioning = None
        self.model = model
        self.model_forward_callback = model_forward_callback
        self.cross_attention_control_context = None
@@ -433,14 +568,44 @@ class InvokeAIDiffuserComponent:
        # denoising step.
        cross_attention_kwargs = None
        _, _, h, w = x.shape
        text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
            text_embeddings=conditioning_data.text_embeddings,
        cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
            text_conditionings=conditioning_data.text_embeddings,
            masks=conditioning_data.text_embedding_masks,
            target_height=h,
            target_width=w,
            latent_height=h,
            latent_width=w,
        )
        if regional_prompt_data is not None:
            cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}
        uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
            text_conditionings=[conditioning_data.unconditioned_embeddings],
            masks=[None],
            latent_height=h,
            latent_width=w,
        )

        if cond_text.has_region_masks() or uncond_text.has_region_masks():
            masks = []
            embedding_ranges = []
            for c in [uncond_text, cond_text]:
                if c.has_region_masks():
                    masks.append(c.masks)
                    embedding_ranges.append(c.embedding_ranges)
                else:
                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                    masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
                    embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])

            # The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
            # embeddings will be padded to match this length.
            key_seq_len = 0
            for c in [uncond_text, cond_text]:
                _, seq_len, _ = c.text_conditioning.embeds.shape
                if seq_len > key_seq_len:
                    key_seq_len = seq_len

            cross_attention_kwargs = {
                "regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
                    masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
                )
            }

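A toy walk-through of the pairing above (values are illustrative, not taken from the codebase): suppose the conditioned prompt has two regional sub-prompts of 77 tokens each, while the unconditioned prompt is a single 77-token embedding with no region masks. The unconditioned side then receives a dummy all-ones mask over the full latent, and key_seq_len becomes the longer of the two embeddings:

import torch

h, w = 64, 64  # latent height/width

# Conditioned side: two regional masks plus their token ranges in the concatenated embedding.
cond_masks = torch.zeros((1, 2, h, w), dtype=torch.bool)
cond_masks[0, 0, :, : w // 2] = True   # left half -> prompt tokens [0, 77)
cond_masks[0, 1, :, w // 2 :] = True   # right half -> prompt tokens [77, 154)
cond_ranges = [(0, 77), (77, 154)]

# Unconditioned side: no region masks, so it gets a dummy all-ones mask over its full 77-token embedding.
uncond_masks = torch.ones((1, 1, h, w), dtype=torch.bool)
uncond_ranges = [(0, 77)]

masks = [uncond_masks, cond_masks]            # uncond first, then cond
embedding_ranges = [uncond_ranges, cond_ranges]
key_seq_len = max(77, 154)                    # shorter embeddings are padded up to this length
print(key_seq_len)  # 154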
        # TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
        if conditioning_data.ip_adapter_conditioning is not None:
@@ -455,27 +620,18 @@ class InvokeAIDiffuserComponent:
            }

        added_cond_kwargs = None
        if type(text_embeddings) is SDXLConditioningInfo:
        if cond_text.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": torch.cat(
                    [
                        # TODO: how to pad? just by zeros? or even truncate?
                        conditioning_data.unconditioned_embeddings.pooled_embeds,
                        text_embeddings.pooled_embeds,
                    ],
                    dim=0,
                    [uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
                ),
                "time_ids": torch.cat(
                    [
                        conditioning_data.unconditioned_embeddings.add_time_ids,
                        text_embeddings.add_time_ids,
                    ],
                    dim=0,
                    [uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
                ),
            }

        both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
            conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
            uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
        )
        both_results = self.model_forward_callback(
            x_twice,