mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update various comments related to regional prompting, and delete duplicate _preprocess_regional_prompt_mask(...) function.
This commit is contained in:
parent
54971afe44
commit
845c4e93ae
@ -32,15 +32,16 @@ class RegionalPromptData:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
|
masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
|
||||||
The shape of masks[i] is (num_prompts, height, width), and dtype=bool. The mask is set to True in
|
The shape of masks[i] is (num_prompts, height, width). The mask is set to 1.0 in regions where the
|
||||||
regions where the prompt should be applied, and 0.0 elsewhere.
|
prompt should be applied, and 0.0 elsewhere.
|
||||||
|
|
||||||
embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
|
embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
|
||||||
prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
|
prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
|
||||||
encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
|
encoder_hidden_states[i, embedding_ranges[j].start:embedding_ranges[j].end, :].
|
||||||
|
|
||||||
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
|
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
|
||||||
cross-attention layers).
|
cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
|
||||||
|
explicitly to be sure.
|
||||||
"""
|
"""
|
||||||
attn_masks_by_seq_len = {}
|
attn_masks_by_seq_len = {}
|
||||||
|
|
||||||
|
@ -121,13 +121,14 @@ class RegionalTextConditioningInfo:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
# HACK(ryand): We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. This is
|
||||||
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
# fundamentally an interface issue, as the SDXL Compel nodes are not designed to be used in the way that
|
||||||
# the conditioning info, then we shouldn't allow it to be passed in.
|
# we use them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||||
# How does Compel handle this? Options that come to mind:
|
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||||
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
|
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||||
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
|
#
|
||||||
# this is likely the global prompt.
|
# An improvement could be to use the pooled embeds from the prompt with the largest region, as this is
|
||||||
|
# most likely to be a global prompt.
|
||||||
if pooled_embedding is None:
|
if pooled_embedding is None:
|
||||||
pooled_embedding = text_embedding_info.pooled_embeds
|
pooled_embedding = text_embedding_info.pooled_embeds
|
||||||
if add_time_ids is None:
|
if add_time_ids is None:
|
||||||
@ -433,28 +434,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||||
|
|
||||||
def _preprocess_regional_prompt_mask(
|
|
||||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Preprocess a regional prompt mask to match the target height and width.
|
|
||||||
|
|
||||||
If mask is None, returns a mask of all ones with the target height and width.
|
|
||||||
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
|
||||||
"""
|
|
||||||
if mask is None:
|
|
||||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
|
||||||
|
|
||||||
tf = torchvision.transforms.Resize(
|
|
||||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
|
||||||
)
|
|
||||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
|
||||||
mask = tf(mask)
|
|
||||||
|
|
||||||
return mask
|
|
||||||
|
|
||||||
def _apply_standard_conditioning(
|
def _apply_standard_conditioning(
|
||||||
self,
|
self,
|
||||||
x,
|
x,
|
||||||
@ -470,8 +449,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
# HACK(ryand): We should only have to call _prepare_text_embeddings once, but we currently re-run it on every
|
# TODO(ryand): We currently call from_text_conditioning_and_masks(...) and from_masks_and_ranges(...) for every
|
||||||
# denoising step.
|
# denoising step. The text conditionings and masks are not changing from step-to-step, so this really only needs
|
||||||
|
# to be done once. While this seems painfully inefficient, the time spent is typically negligible compared to
|
||||||
|
# the forward inference pass of the UNet. The main reason that this hasn't been moved up to eliminate redundancy
|
||||||
|
# is that it is slightly awkward to handle both standard conditioning and sequential conditioning further up the
|
||||||
|
# stack.
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
_, _, h, w = x.shape
|
_, _, h, w = x.shape
|
||||||
cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
|
cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
|
||||||
|
Loading…
Reference in New Issue
Block a user