Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation.

This commit is contained in:
Ryan Dick 2024-02-26 17:34:37 -05:00
parent d132fb4818
commit b0fcbe552e
6 changed files with 85 additions and 66 deletions

View File

@ -1,6 +1,4 @@
import numpy as np
import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
@ -9,8 +7,7 @@ from invokeai.app.invocations.baseinvocation import (
WithMetadata,
invocation,
)
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
@invocation(
@ -24,27 +21,10 @@ class AddConditioningMaskInvocation(BaseInvocation):
"""Add a mask to an existing conditioning tensor."""
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
image: ImageField = InputField(
description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
"Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
)
@staticmethod
def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
"""Convert a PIL image to a uint8 mask tensor."""
np_image = np.array(image)
torch_image = torch.from_numpy(np_image[:, :, 0])
mask = torch_image >= 128
return mask.to(dtype=torch.uint8)
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
def invoke(self, context: InvocationContext) -> ConditioningOutput:
image = context.services.images.get_pil_image(self.image.image_name)
mask = self.convert_image_to_mask(image)
mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
context.services.latents.save(mask_name, mask)
self.conditioning.mask_name = mask_name
self.conditioning.mask = self.mask
return ConditioningOutput(conditioning=self.conditioning)
@ -56,33 +36,26 @@ class AddConditioningMaskInvocation(BaseInvocation):
version="1.0.0",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"""Create a mask image containing a rectangular mask region."""
"""Create a rectangular mask."""
height: int = InputField(description="The height of the image.")
width: int = InputField(description="The width of the image.")
y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).")
y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).")
x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).")
x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).")
height: int = InputField(description="The height of the entire mask.")
width: int = InputField(description="The width of the entire mask.")
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
def invoke(self, context: InvocationContext) -> ImageOutput:
mask = np.zeros((self.height, self.width, 3), dtype=np.uint8)
mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255
mask_image = Image.fromarray(mask)
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
mask[
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
] = True
image_dto = context.services.images.create(
image=mask_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask)
return MaskOutput(
mask=MaskField(mask_name=mask_name),
width=self.width,
height=self.height,
)

View File

@ -345,10 +345,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
mask_name = positive_conditioning.mask_name
mask = None
if mask_name is not None:
mask = context.services.latents.get(mask_name)
mask = positive_conditioning.mask
if mask is not None:
mask = context.services.latents.get(mask.mask_name)
text_embeddings_masks.append(mask)
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)

View File

@ -233,6 +233,24 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board")
class MaskField(BaseModel):
"""A mask primitive field."""
mask_name: str = Field(description="The name of the mask.")
@invocation_output("mask_output")
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor.
dtype: torch.bool
shape: (1, height, width).
"""
mask: MaskField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("image_output")
class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
@ -428,10 +446,10 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask_name: Optional[str] = Field(
mask: Optional[MaskField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
"regions should be set to 1.",
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to 1.",
)

View File

@ -20,6 +20,8 @@ class ExtraConditioningInfo:
@dataclass
class BasicConditioningInfo:
"""SD 1/2 text conditioning information produced by Compel."""
embeds: torch.Tensor
extra_conditioning: Optional[ExtraConditioningInfo]
@ -30,6 +32,8 @@ class BasicConditioningInfo:
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor

View File

@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
w //= 2
assert h * w == query_seq_len
# Convert the bool masks to float masks.
per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
# Apply max-pooling to resize the masks to the target spatial dimensions.
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
# here.

View File

@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent:
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
# HACK(ryand): Figure out how to know the target device/dtype.
return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
else:
# HACK(ryand): It would make more sense to do NEAREST resising with an integer dtype, and probably on the
# CPU.
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
mask = tf(mask)
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent:
target_height: int,
target_width: int,
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
"""Prepare text embeddings and associated masks for use in the UNet forward pass.
- Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
SDXLConditioningInfo).
- Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
- If all masks are None, skips all mask processing.
Returns:
Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
(text_embedding, regional_prompt_data)
- text_embedding: The concatenated text embeddings.
- regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
"""
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent:
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
# the conditioning info, then we shouldn't allow it to be passed in.
# How does Compel handle this? Options that come to mind:
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
# this is likely the global prompt.
if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None: