mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update various comments related to regional prompting, and delete duplicate _preprocess_regional_prompt_mask(...) function.
This commit is contained in:
parent
54971afe44
commit
845c4e93ae
@ -32,15 +32,16 @@ class RegionalPromptData:
|
||||
|
||||
Args:
|
||||
masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
|
||||
The shape of masks[i] is (num_prompts, height, width), and dtype=bool. The mask is set to True in
|
||||
regions where the prompt should be applied, and 0.0 elsewhere.
|
||||
The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
|
||||
prompt should be applied, and 0.0 elsewhere.
|
||||
|
||||
embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
|
||||
prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
|
||||
encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
|
||||
|
||||
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
|
||||
cross-attention layers).
|
||||
cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
|
||||
explicitly to be sure.
|
||||
"""
|
||||
attn_masks_by_seq_len = {}
|
||||
|
||||
|
@ -121,13 +121,14 @@ class RegionalTextConditioningInfo:
|
||||
)
|
||||
|
||||
if is_sdxl:
|
||||
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
||||
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
||||
# the conditioning info, then we shouldn't allow it to be passed in.
|
||||
# How does Compel handle this? Options that come to mind:
|
||||
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
|
||||
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
|
||||
# this is likely the global prompt.
|
||||
# HACK(ryand): We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
|
||||
# fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
|
||||
# we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||
#
|
||||
# An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
|
||||
# most likely to be a global prompt.
|
||||
if pooled_embedding is None:
|
||||
pooled_embedding = text_embedding_info.pooled_embeds
|
||||
if add_time_ids is None:
|
||||
@ -433,28 +434,6 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
||||
"""
|
||||
if mask is None:
|
||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
|
||||
return mask
|
||||
|
||||
def _apply_standard_conditioning(
|
||||
self,
|
||||
x,
|
||||
@ -470,8 +449,12 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
# HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
|
||||
# denoising step.
|
||||
# TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
|
||||
# denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
|
||||
# to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
|
||||
# the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
|
||||
# is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
|
||||
# stack.
|
||||
cross_attention_kwargs = None
|
||||
_, _, h, w = x.shape
|
||||
cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
|
||||
|
Loading…
Reference in New Issue
Block a user