Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation.

This commit is contained in:
Ryan Dick
2024-02-26 17:34:37 -05:00
parent d132fb4818
commit b0fcbe552e
6 changed files with 85 additions and 66 deletions

View File

@ -20,6 +20,8 @@ class ExtraConditioningInfo:
@dataclass
class BasicConditioningInfo:
"""SD 1/2 text conditioning information produced by Compel."""
embeds: torch.Tensor
extra_conditioning: Optional[ExtraConditioningInfo]
@ -30,6 +32,8 @@ class BasicConditioningInfo:
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor

View File

@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
w //= 2
assert h * w == query_seq_len
# Convert the bool masks to float masks.
per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
# Apply max-pooling to resize the masks to the target spatial dimensions.
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
# here.

View File

@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent:
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
# HACK(ryand): Figure out how to know the target device/dtype.
return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
else:
# HACK(ryand): It would make more sense to do NEAREST resising with an integer dtype, and probably on the
# CPU.
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
mask = tf(mask)
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent:
target_height: int,
target_width: int,
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
"""Prepare text embeddings and associated masks for use in the UNet forward pass.
- Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
SDXLConditioningInfo).
- Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
- If all masks are None, skips all mask processing.
Returns:
Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
(text_embedding, regional_prompt_data)
- text_embedding: The concatenated text embeddings.
- regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
"""
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent:
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
# the conditioning info, then we shouldn't allow it to be passed in.
# How does Compel handle this? Options that come to mind:
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
# this is likely the global prompt.
if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None: