mirror of https://github.com/invoke-ai/InvokeAI
Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation.
@@ -20,6 +20,8 @@ class ExtraConditioningInfo:

@dataclass
class BasicConditioningInfo:
    """SD 1/2 text conditioning information produced by Compel."""

    embeds: torch.Tensor
    extra_conditioning: Optional[ExtraConditioningInfo]

@@ -30,6 +32,8 @@ class BasicConditioningInfo:

@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

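Not part of the diff — a minimal sketch of how these two dataclasses relate, assuming embedding tensors produced by Compel (the shapes below are illustrative, not taken from the commit):

import torch

# SD 1/2: a single per-token embedding tensor plus optional extra conditioning.
sd_cond = BasicConditioningInfo(
    embeds=torch.zeros(1, 77, 768),
    extra_conditioning=None,
)

# SDXL: inherits `embeds`/`extra_conditioning` and adds the pooled embedding
# and the add_time_ids (original/crop/target size conditioning).
sdxl_cond = SDXLConditioningInfo(
    embeds=torch.zeros(1, 77, 2048),
    extra_conditioning=None,
    pooled_embeds=torch.zeros(1, 1280),
    add_time_ids=torch.zeros(1, 6),
)
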
@@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
        w //= 2
        assert h * w == query_seq_len

        # Convert the bool masks to float masks.
        per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)

        # Apply max-pooling to resize the masks to the target spatial dimensions.
        # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
        # here.

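Not part of the diff — a short sketch of the max-pooling downsampling the comment above describes, assuming the per-prompt masks start at latent resolution and must be shrunk to this attention level's h x w (a region stays "on" if any pixel inside the pooling window was masked):

import torch
import torch.nn.functional as F

# (batch, num_prompts, H, W) boolean region masks at latent resolution.
masks = torch.zeros(1, 2, 64, 64, dtype=torch.bool)
masks[:, 0, :, :32] = True   # prompt 0 covers the left half
masks[:, 1, :, 32:] = True   # prompt 1 covers the right half

h, w = 8, 8                  # spatial dims of this attention level (h * w == query_seq_len)
pooled = F.max_pool2d(masks.to(torch.float32), kernel_size=(64 // h, 64 // w))
per_prompt_query_masks = pooled.flatten(2)   # (batch, num_prompts, query_seq_len)
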
@@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent:
    def _preprocess_regional_prompt_mask(
        self, mask: Optional[torch.Tensor], target_height: int, target_width: int
    ) -> torch.Tensor:
        """Preprocess a regional prompt mask to match the target height and width.

        If mask is None, returns a mask of all ones with the target height and width.
        If mask is not None, resizes the mask to the target height and width using nearest-neighbor interpolation.

        Returns:
            torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
        """
        if mask is None:
            # HACK(ryand): Figure out how to know the target device/dtype.
            return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
        else:
            # HACK(ryand): It would make more sense to do NEAREST resizing with an integer dtype, and probably on the
            # CPU.
            tf = torchvision.transforms.Resize(
                (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
            )
            mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
            mask = tf(mask)
            return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)

        tf = torchvision.transforms.Resize(
            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
        mask = tf(mask)

        return mask

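Not part of the diff — a usage-style sketch of the mask preprocessing described in the docstring, rewritten with torch.nn.functional.interpolate purely for illustration (the helper name and the interpolate-based resize are assumptions, not the committed code):

from typing import Optional
import torch
import torch.nn.functional as F

def resize_region_mask(mask: Optional[torch.Tensor], target_height: int, target_width: int) -> torch.Tensor:
    # A missing mask means "the whole image": return an all-ones bool mask.
    if mask is None:
        return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
    # Nearest-neighbor resize keeps the mask binary; resize as float, cast back to bool.
    mask = mask.unsqueeze(0)  # (1, h, w) -> (1, 1, h, w)
    mask = F.interpolate(mask.to(torch.float32), size=(target_height, target_width), mode="nearest")
    return mask.to(torch.bool)

region = torch.zeros(1, 96, 96, dtype=torch.bool)
region[:, :48, :] = True                          # prompt applies to the top half
print(resize_region_mask(region, 12, 12).shape)   # torch.Size([1, 1, 12, 12])
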
@@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent:
        target_height: int,
        target_width: int,
    ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
        """Prepare text embeddings and associated masks for use in the UNet forward pass.

        - Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
          SDXLConditioningInfo).
        - Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
        - If all masks are None, skips all mask processing.

        Returns:
            Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
                (text_embedding, regional_prompt_data)
                - text_embedding: The concatenated text embeddings.
                - regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
        """
        is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo

        all_masks_are_none = all(mask is None for mask in masks)

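Not part of the diff — a simplified sketch of the concatenation the docstring describes, assuming two regional prompts (the tensor shapes and the (start, end) range bookkeeping are illustrative assumptions):

import torch

embeds = [torch.zeros(1, 77, 768), torch.zeros(1, 77, 768)]          # one embedding per regional prompt
masks = [torch.ones(1, 1, 64, 64, dtype=torch.bool) for _ in embeds]

# Track which slice of the concatenated token axis belongs to each region.
embedding_ranges, start = [], 0
for e in embeds:
    embedding_ranges.append((start, start + e.shape[1]))
    start += e.shape[1]

text_embedding = torch.cat(embeds, dim=1)   # (1, 154, 768): single tensor fed to the UNet
stacked_masks = torch.cat(masks, dim=1)     # (1, num_regions, 64, 64): one channel per region
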
@@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent:
            # We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
            # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
            # the conditioning info, then we shouldn't allow it to be passed in.
            # How does Compel handle this? Options that come to mind:
            # - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
            # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
            #   this is likely the global prompt.
            if pooled_embedding is None:
                pooled_embedding = text_embedding_info.pooled_embeds
            if add_time_ids is None:
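Not part of the diff — a sketch of the second option floated in the TODO above (choosing the prompt with the largest mask area); this is an assumption about one possible resolution, not what the commit implements, and it assumes every mask has already been materialized as a bool tensor:

largest_mask, largest_info = max(
    zip(masks, text_embeddings),
    key=lambda pair: int(pair[0].sum()),   # mask area in pixels
)
pooled_embedding = largest_info.pooled_embeds
add_time_ids = largest_info.add_time_ids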