Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation.

This commit is contained in:
Ryan Dick 2024-02-26 17:34:37 -05:00
parent d132fb4818
commit b0fcbe552e
6 changed files with 85 additions and 66 deletions

View File

@ -1,6 +1,4 @@
import numpy as np
import torch import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
@ -9,8 +7,7 @@ from invokeai.app.invocations.baseinvocation import (
WithMetadata, WithMetadata,
invocation, invocation,
) )
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
@invocation( @invocation(
@ -24,27 +21,10 @@ class AddConditioningMaskInvocation(BaseInvocation):
"""Add a mask to an existing conditioning tensor.""" """Add a mask to an existing conditioning tensor."""
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
image: ImageField = InputField( mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
"Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
)
@staticmethod
def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
"""Convert a PIL image to a uint8 mask tensor."""
np_image = np.array(image)
torch_image = torch.from_numpy(np_image[:, :, 0])
mask = torch_image >= 128
return mask.to(dtype=torch.uint8)
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
image = context.services.images.get_pil_image(self.image.image_name) self.conditioning.mask = self.mask
mask = self.convert_image_to_mask(image)
mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
context.services.latents.save(mask_name, mask)
self.conditioning.mask_name = mask_name
return ConditioningOutput(conditioning=self.conditioning) return ConditioningOutput(conditioning=self.conditioning)
@ -56,33 +36,26 @@ class AddConditioningMaskInvocation(BaseInvocation):
version="1.0.0", version="1.0.0",
) )
class RectangleMaskInvocation(BaseInvocation, WithMetadata): class RectangleMaskInvocation(BaseInvocation, WithMetadata):
"""Create a mask image containing a rectangular mask region.""" """Create a rectangular mask."""
height: int = InputField(description="The height of the image.") height: int = InputField(description="The height of the entire mask.")
width: int = InputField(description="The width of the image.") width: int = InputField(description="The width of the entire mask.")
y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).") y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).") x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).") rectangle_height: int = InputField(description="The height of the rectangular masked region.")
x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).") rectangle_width: int = InputField(description="The width of the rectangular masked region.")
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
mask = np.zeros((self.height, self.width, 3), dtype=np.uint8) mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255 mask[
mask_image = Image.fromarray(mask) :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
] = True
image_dto = context.services.images.create( mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
image=mask_image, context.services.latents.save(mask_name, mask)
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, return MaskOutput(
node_id=self.id, mask=MaskField(mask_name=mask_name),
session_id=context.graph_execution_state_id, width=self.width,
is_intermediate=self.is_intermediate, height=self.height,
metadata=self.metadata,
workflow=context.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
) )

View File

@ -345,10 +345,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)) text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
mask_name = positive_conditioning.mask_name mask = positive_conditioning.mask
mask = None if mask is not None:
if mask_name is not None: mask = context.services.latents.get(mask.mask_name)
mask = context.services.latents.get(mask_name)
text_embeddings_masks.append(mask) text_embeddings_masks.append(mask)
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)

View File

@ -233,6 +233,24 @@ class BoardField(BaseModel):
board_id: str = Field(description="The id of the board") board_id: str = Field(description="The id of the board")
class MaskField(BaseModel):
"""A mask primitive field."""
mask_name: str = Field(description="The name of the mask.")
@invocation_output("mask_output")
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor.
dtype: torch.bool
shape: (1, height, width).
"""
mask: MaskField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")
@invocation_output("image_output") @invocation_output("image_output")
class ImageOutput(BaseInvocationOutput): class ImageOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image""" """Base class for nodes that output a single image"""
@ -428,10 +446,10 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value""" """A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor") conditioning_name: str = Field(description="The name of conditioning tensor")
mask_name: Optional[str] = Field( mask: Optional[MaskField] = Field(
default=None, default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included " description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"regions should be set to 1.", "included regions should be set to 1.",
) )

View File

@ -20,6 +20,8 @@ class ExtraConditioningInfo:
@dataclass @dataclass
class BasicConditioningInfo: class BasicConditioningInfo:
"""SD 1/2 text conditioning information produced by Compel."""
embeds: torch.Tensor embeds: torch.Tensor
extra_conditioning: Optional[ExtraConditioningInfo] extra_conditioning: Optional[ExtraConditioningInfo]
@ -30,6 +32,8 @@ class BasicConditioningInfo:
@dataclass @dataclass
class SDXLConditioningInfo(BasicConditioningInfo): class SDXLConditioningInfo(BasicConditioningInfo):
"""SDXL text conditioning information produced by Compel."""
pooled_embeds: torch.Tensor pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor add_time_ids: torch.Tensor

View File

@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
w //= 2 w //= 2
assert h * w == query_seq_len assert h * w == query_seq_len
# Convert the bool masks to float masks.
per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
# Apply max-pooling to resize the masks to the target spatial dimensions. # Apply max-pooling to resize the masks to the target spatial dimensions.
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
# here. # here.

View File

@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent:
def _preprocess_regional_prompt_mask( def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int self, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor: ) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None: if mask is None:
# HACK(ryand): Figure out how to know the target device/dtype. return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
else: tf = torchvision.transforms.Resize(
# HACK(ryand): It would make more sense to do NEAREST resising with an integer dtype, and probably on the (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
# CPU. )
tf = torchvision.transforms.Resize( mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST mask = tf(mask)
)
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask return mask
@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent:
target_height: int, target_height: int,
target_width: int, target_width: int,
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]: ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
"""Prepare text embeddings and associated masks for use in the UNet forward pass.
- Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
SDXLConditioningInfo).
- Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
- If all masks are None, skips all mask processing.
Returns:
Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
(text_embedding, regional_prompt_data)
- text_embedding: The concatenated text embeddings.
- regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
"""
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks) all_masks_are_none = all(mask is None for mask in masks)
@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent:
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. # We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
# the conditioning info, then we shouldn't allow it to be passed in. # the conditioning info, then we shouldn't allow it to be passed in.
# How does Compel handle this? Options that come to mind:
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
# this is likely the global prompt.
if pooled_embedding is None: if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None: if add_time_ids is None: