Update various comments related to regional prompting, and delete duplicate _preprocess_regional_prompt_mask(...) function.

Ryan Dick 2024-02-28 10:20:22 -05:00
parent 54971afe44
commit 845c4e93ae
2 changed files with 18 additions and 34 deletions


@@ -32,15 +32,16 @@ class RegionalPromptData:
Args:
masks (list[torch.Tensor]): masks[i] contains the region masks for the i'th sample in the batch.
- The shape of masks[i] is (num_prompts, height, width), and dtype=bool. The mask is set to True in
- regions where the prompt should be applied, and 0.0 elsewhere.
+ The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
+ prompt should be applied, and 0.0 elsewhere.
embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
- cross-attention layers).
+ cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
+ explicitly to be sure.
"""
attn_masks_by_seq_len = {}
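For illustration, the sketch below shows one way that per-prompt region masks and embedding ranges of the shape described in the docstring above could be combined into a cross-attention mask of shape (batch, query_seq_len, key_seq_len). This is a hypothetical standalone sketch, not the RegionalPromptData implementation; `build_cross_attn_mask`, the stand-in `Range` dataclass, and the nearest-neighbor downsampling of masks to the query resolution are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class Range:
    # Stand-in for the Range type referenced in the docstring above.
    start: int
    end: int


def build_cross_attn_mask(masks, embedding_ranges, key_seq_len, query_h, query_w):
    """Combine per-prompt region masks and embedding ranges into a
    (batch, query_seq_len, key_seq_len) cross-attention mask (hypothetical sketch)."""
    batch_size = len(masks)
    attn_mask = torch.zeros((batch_size, query_h * query_w, key_seq_len))
    for i in range(batch_size):
        # Downsample each prompt's (height, width) mask to the query resolution.
        resized = F.interpolate(
            masks[i].unsqueeze(1).float(), size=(query_h, query_w), mode="nearest"
        )
        flat = resized.view(masks[i].shape[0], -1)  # (num_prompts, query_seq_len)
        for j, r in enumerate(embedding_ranges[i]):
            # Query positions inside prompt j's region attend to the key tokens
            # in that prompt's embedding range.
            attn_mask[i, :, r.start : r.end] = flat[j].unsqueeze(-1)
    return attn_mask


# Toy usage: two prompts on a 64x64 mask, attending at 8x8 latent resolution.
masks = [torch.zeros((2, 64, 64))]
masks[0][0, :, :32] = 1.0  # prompt 0: left half
masks[0][1, :, 32:] = 1.0  # prompt 1: right half
ranges = [[Range(0, 77), Range(77, 154)]]
print(build_cross_attn_mask(masks, ranges, key_seq_len=154, query_h=8, query_w=8).shape)
# torch.Size([1, 64, 154])
```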


@@ -121,13 +121,14 @@ class RegionalTextConditioningInfo:
)
if is_sdxl:
- # We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
- # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
- # the conditioning info, then we shouldn't allow it to be passed in.
- # How does Compel handle this? Options that come to mind:
- # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
- # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
- #   this is likely the global prompt.
+ # HACK(ryand): We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
+ # fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
+ # we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
+ # pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
+ # pretty major breaking change to a popular node, so for now we use this hack.
+ #
+ # An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
+ # most likely to be a global prompt.
if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None:
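The comment above floats one possible improvement: take the pooled_embeds and add_time_ids from the prompt with the largest region mask, on the assumption that it is the global prompt. A minimal hypothetical sketch of that idea follows; `pick_global_pooled_embeds` and the SimpleNamespace stand-ins are illustrative only, not part of the InvokeAI codebase, and it assumes each conditioning comes with an optional region mask.

```python
import torch
from types import SimpleNamespace


def pick_global_pooled_embeds(text_embedding_infos, masks):
    """Return pooled_embeds/add_time_ids from the prompt covering the largest region."""
    # A prompt with no mask is treated as covering the whole image.
    areas = [float("inf") if m is None else m.sum().item() for m in masks]
    global_idx = max(range(len(areas)), key=lambda i: areas[i])
    info = text_embedding_infos[global_idx]
    return info.pooled_embeds, info.add_time_ids


# Toy usage: two SDXL-style conditionings; the first covers the larger region.
infos = [
    SimpleNamespace(pooled_embeds=torch.zeros(1, 1280), add_time_ids=torch.zeros(1, 6)),
    SimpleNamespace(pooled_embeds=torch.ones(1, 1280), add_time_ids=torch.ones(1, 6)),
]
masks = [torch.ones(64, 64), torch.zeros(64, 64)]
pooled, time_ids = pick_global_pooled_embeds(infos, masks)  # selects infos[0]
```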
@@ -433,28 +434,6 @@ class InvokeAIDiffuserComponent:
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
- def _preprocess_regional_prompt_mask(
- self, mask: Optional[torch.Tensor], target_height: int, target_width: int
- ) -> torch.Tensor:
- """Preprocess a regional prompt mask to match the target height and width.
- If mask is None, returns a mask of all ones with the target height and width.
- If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
- Returns:
- torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
- """
- if mask is None:
- return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
- tf = torchvision.transforms.Resize(
- (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
- )
- mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
- mask = tf(mask)
- return mask
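The deleted helper above resizes a (1, h, w) mask to the target size with nearest-neighbor interpolation and falls back to an all-ones mask when no mask is supplied. The standalone sketch below demonstrates the same shape contract, but uses float masks and torch.nn.functional.interpolate rather than the bool masks and torchvision.transforms.Resize of the original, so the dtype differs from the deleted code.

```python
import torch
import torch.nn.functional as F


def preprocess_region_mask(mask, target_height, target_width):
    """Sketch of the deleted helper's behavior: resize with nearest-neighbor
    interpolation, or return an all-ones mask when no mask is given."""
    if mask is None:
        return torch.ones((1, 1, target_height, target_width))
    # (1, h, w) -> (1, 1, h, w) so interpolate sees batch and channel dims.
    return F.interpolate(
        mask.unsqueeze(0).float(), size=(target_height, target_width), mode="nearest"
    )


mask = torch.zeros((1, 64, 64))
mask[:, :, :32] = 1.0  # prompt applies to the left half of the image
print(preprocess_region_mask(mask, 96, 96).shape)  # torch.Size([1, 1, 96, 96])
print(preprocess_region_mask(None, 96, 96).shape)  # torch.Size([1, 1, 96, 96])
```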
def _apply_standard_conditioning(
self,
x,
@@ -470,8 +449,12 @@ class InvokeAIDiffuserComponent:
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
- # HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
- # denoising step.
+ # TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
+ # denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
+ # to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
+ # the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
+ # is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
+ # stack.
cross_attention_kwargs = None
_, _, h, w = x.shape
cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
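The TODO above notes that the regional conditioning is rebuilt on every denoising step even though its inputs never change. Below is a hypothetical sketch of the kind of hoisting it describes: build once, reuse on every step. The class, the function name, and the dict payload are placeholders, not the actual InvokeAI refactor.

```python
class CachedRegionalConditioning:
    """Build the step-invariant regional conditioning once and reuse it on every step."""

    def __init__(self, build_fn):
        # build_fn closes over the text conditionings, masks, and latent size,
        # none of which change between denoising steps.
        self._build_fn = build_fn
        self._cached = None

    def get(self):
        if self._cached is None:
            self._cached = self._build_fn()
        return self._cached


build_calls = 0


def build_regional_conditioning():
    # Stand-in for the from_text_conditioning_and_masks(...) / from_masks_and_ranges(...) work.
    global build_calls
    build_calls += 1
    return {"cond_text": "..."}


cache = CachedRegionalConditioning(build_regional_conditioning)
for _step in range(20):  # stand-in denoising loop
    cond_text = cache.get()
print(build_calls)  # 1 -- built once instead of once per step
```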