mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation.
This commit is contained in:
parent
d132fb4818
commit
b0fcbe552e
@ -1,6 +1,4 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -9,8 +7,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
WithMetadata,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
|
||||
|
||||
|
||||
@invocation(
|
||||
@ -24,27 +21,10 @@ class AddConditioningMaskInvocation(BaseInvocation):
|
||||
"""Add a mask to an existing conditioning tensor."""
|
||||
|
||||
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
||||
image: ImageField = InputField(
|
||||
description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
|
||||
"Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
|
||||
"""Convert a PIL image to a uint8 mask tensor."""
|
||||
np_image = np.array(image)
|
||||
torch_image = torch.from_numpy(np_image[:, :, 0])
|
||||
mask = torch_image >= 128
|
||||
return mask.to(dtype=torch.uint8)
|
||||
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
mask = self.convert_image_to_mask(image)
|
||||
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
|
||||
self.conditioning.mask_name = mask_name
|
||||
self.conditioning.mask = self.mask
|
||||
return ConditioningOutput(conditioning=self.conditioning)
|
||||
|
||||
|
||||
@ -56,33 +36,26 @@ class AddConditioningMaskInvocation(BaseInvocation):
|
||||
version="1.0.0",
|
||||
)
|
||||
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Create a mask image containing a rectangular mask region."""
|
||||
"""Create a rectangular mask."""
|
||||
|
||||
height: int = InputField(description="The height of the image.")
|
||||
width: int = InputField(description="The width of the image.")
|
||||
y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).")
|
||||
y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).")
|
||||
x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).")
|
||||
x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).")
|
||||
height: int = InputField(description="The height of the entire mask.")
|
||||
width: int = InputField(description="The width of the entire mask.")
|
||||
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
|
||||
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
|
||||
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
|
||||
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
mask = np.zeros((self.height, self.width, 3), dtype=np.uint8)
|
||||
mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255
|
||||
mask_image = Image.fromarray(mask)
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||
mask[
|
||||
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
|
||||
] = True
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=mask_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
|
||||
return MaskOutput(
|
||||
mask=MaskField(mask_name=mask_name),
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
||||
|
@ -345,10 +345,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
||||
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
|
||||
|
||||
mask_name = positive_conditioning.mask_name
|
||||
mask = None
|
||||
if mask_name is not None:
|
||||
mask = context.services.latents.get(mask_name)
|
||||
mask = positive_conditioning.mask
|
||||
if mask is not None:
|
||||
mask = context.services.latents.get(mask.mask_name)
|
||||
text_embeddings_masks.append(mask)
|
||||
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
|
@ -233,6 +233,24 @@ class BoardField(BaseModel):
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
class MaskField(BaseModel):
|
||||
"""A mask primitive field."""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask.")
|
||||
|
||||
|
||||
@invocation_output("mask_output")
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""A torch mask tensor.
|
||||
dtype: torch.bool
|
||||
shape: (1, height, width).
|
||||
"""
|
||||
|
||||
mask: MaskField = OutputField(description="The mask.")
|
||||
width: int = OutputField(description="The width of the mask in pixels.")
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
||||
@invocation_output("image_output")
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
@ -428,10 +446,10 @@ class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask_name: Optional[str] = Field(
|
||||
mask: Optional[MaskField] = Field(
|
||||
default=None,
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
||||
"regions should be set to 1.",
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to 1.",
|
||||
)
|
||||
|
||||
|
||||
|
@ -20,6 +20,8 @@ class ExtraConditioningInfo:
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
"""SD 1/2 text conditioning information produced by Compel."""
|
||||
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||
|
||||
@ -30,6 +32,8 @@ class BasicConditioningInfo:
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
"""SDXL text conditioning information produced by Compel."""
|
||||
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
|
@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
|
||||
w //= 2
|
||||
assert h * w == query_seq_len
|
||||
|
||||
# Convert the bool masks to float masks.
|
||||
per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
|
||||
|
||||
# Apply max-pooling to resize the masks to the target spatial dimensions.
|
||||
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
|
||||
# here.
|
||||
|
@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent:
|
||||
def _preprocess_regional_prompt_mask(
|
||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
||||
) -> torch.Tensor:
|
||||
"""Preprocess a regional prompt mask to match the target height and width.
|
||||
|
||||
If mask is None, returns a mask of all ones with the target height and width.
|
||||
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
||||
"""
|
||||
if mask is None:
|
||||
# HACK(ryand): Figure out how to know the target device/dtype.
|
||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
|
||||
else:
|
||||
# HACK(ryand): It would make more sense to do NEAREST resising with an integer dtype, and probably on the
|
||||
# CPU.
|
||||
tf = torchvision.transforms.Resize(
|
||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
||||
|
||||
tf = torchvision.transforms.Resize(
|
||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
|
||||
return mask
|
||||
|
||||
@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent:
|
||||
target_height: int,
|
||||
target_width: int,
|
||||
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
||||
"""Prepare text embeddings and associated masks for use in the UNet forward pass.
|
||||
|
||||
- Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
|
||||
SDXLConditioningInfo).
|
||||
- Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
|
||||
- If all masks are None, skips all mask processing.
|
||||
|
||||
Returns:
|
||||
Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
||||
(text_embedding, regional_prompt_data)
|
||||
- text_embedding: The concatenated text embeddings.
|
||||
- regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
|
||||
"""
|
||||
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
|
||||
|
||||
all_masks_are_none = all(mask is None for mask in masks)
|
||||
@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent:
|
||||
# We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
||||
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
||||
# the conditioning info, then we shouldn't allow it to be passed in.
|
||||
# How does Compel handle this? Options that come to mind:
|
||||
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
|
||||
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
|
||||
# this is likely the global prompt.
|
||||
if pooled_embedding is None:
|
||||
pooled_embedding = text_embedding_info.pooled_embeds
|
||||
if add_time_ids is None:
|
||||
|
Loading…
Reference in New Issue
Block a user