mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation.
This commit is contained in:
parent
d132fb4818
commit
b0fcbe552e
@ -1,6 +1,4 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -9,8 +7,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
WithMetadata,
|
WithMetadata,
|
||||||
invocation,
|
invocation,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
@ -24,27 +21,10 @@ class AddConditioningMaskInvocation(BaseInvocation):
|
|||||||
"""Add a mask to an existing conditioning tensor."""
|
"""Add a mask to an existing conditioning tensor."""
|
||||||
|
|
||||||
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.")
|
||||||
image: ImageField = InputField(
|
mask: MaskField = InputField(description="A mask to add to the conditioning tensor.")
|
||||||
description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. "
|
|
||||||
"Pixels <128 are excluded from the mask, pixels >=128 are included in the mask."
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def convert_image_to_mask(image: Image.Image) -> torch.Tensor:
|
|
||||||
"""Convert a PIL image to a uint8 mask tensor."""
|
|
||||||
np_image = np.array(image)
|
|
||||||
torch_image = torch.from_numpy(np_image[:, :, 0])
|
|
||||||
mask = torch_image >= 128
|
|
||||||
return mask.to(dtype=torch.uint8)
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
self.conditioning.mask = self.mask
|
||||||
mask = self.convert_image_to_mask(image)
|
|
||||||
|
|
||||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask"
|
|
||||||
context.services.latents.save(mask_name, mask)
|
|
||||||
|
|
||||||
self.conditioning.mask_name = mask_name
|
|
||||||
return ConditioningOutput(conditioning=self.conditioning)
|
return ConditioningOutput(conditioning=self.conditioning)
|
||||||
|
|
||||||
|
|
||||||
@ -56,33 +36,26 @@ class AddConditioningMaskInvocation(BaseInvocation):
|
|||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
)
|
)
|
||||||
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||||
"""Create a mask image containing a rectangular mask region."""
|
"""Create a rectangular mask."""
|
||||||
|
|
||||||
height: int = InputField(description="The height of the image.")
|
height: int = InputField(description="The height of the entire mask.")
|
||||||
width: int = InputField(description="The width of the image.")
|
width: int = InputField(description="The width of the entire mask.")
|
||||||
y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).")
|
y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")
|
||||||
y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).")
|
x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
|
||||||
x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).")
|
rectangle_height: int = InputField(description="The height of the rectangular masked region.")
|
||||||
x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).")
|
rectangle_width: int = InputField(description="The width of the rectangular masked region.")
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
mask = np.zeros((self.height, self.width, 3), dtype=np.uint8)
|
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||||
mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255
|
mask[
|
||||||
mask_image = Image.fromarray(mask)
|
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
|
||||||
|
] = True
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||||
image=mask_image,
|
context.services.latents.save(mask_name, mask)
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
image_category=ImageCategory.GENERAL,
|
return MaskOutput(
|
||||||
node_id=self.id,
|
mask=MaskField(mask_name=mask_name),
|
||||||
session_id=context.graph_execution_state_id,
|
width=self.width,
|
||||||
is_intermediate=self.is_intermediate,
|
height=self.height,
|
||||||
metadata=self.metadata,
|
|
||||||
workflow=context.workflow,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ImageOutput(
|
|
||||||
image=ImageField(image_name=image_dto.image_name),
|
|
||||||
width=image_dto.width,
|
|
||||||
height=image_dto.height,
|
|
||||||
)
|
)
|
||||||
|
@ -345,10 +345,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
||||||
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
|
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
|
||||||
|
|
||||||
mask_name = positive_conditioning.mask_name
|
mask = positive_conditioning.mask
|
||||||
mask = None
|
if mask is not None:
|
||||||
if mask_name is not None:
|
mask = context.services.latents.get(mask.mask_name)
|
||||||
mask = context.services.latents.get(mask_name)
|
|
||||||
text_embeddings_masks.append(mask)
|
text_embeddings_masks.append(mask)
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
|
@ -233,6 +233,24 @@ class BoardField(BaseModel):
|
|||||||
board_id: str = Field(description="The id of the board")
|
board_id: str = Field(description="The id of the board")
|
||||||
|
|
||||||
|
|
||||||
|
class MaskField(BaseModel):
|
||||||
|
"""A mask primitive field."""
|
||||||
|
|
||||||
|
mask_name: str = Field(description="The name of the mask.")
|
||||||
|
|
||||||
|
|
||||||
|
@invocation_output("mask_output")
|
||||||
|
class MaskOutput(BaseInvocationOutput):
|
||||||
|
"""A torch mask tensor.
|
||||||
|
dtype: torch.bool
|
||||||
|
shape: (1, height, width).
|
||||||
|
"""
|
||||||
|
|
||||||
|
mask: MaskField = OutputField(description="The mask.")
|
||||||
|
width: int = OutputField(description="The width of the mask in pixels.")
|
||||||
|
height: int = OutputField(description="The height of the mask in pixels.")
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("image_output")
|
@invocation_output("image_output")
|
||||||
class ImageOutput(BaseInvocationOutput):
|
class ImageOutput(BaseInvocationOutput):
|
||||||
"""Base class for nodes that output a single image"""
|
"""Base class for nodes that output a single image"""
|
||||||
@ -428,10 +446,10 @@ class ConditioningField(BaseModel):
|
|||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
mask_name: Optional[str] = Field(
|
mask: Optional[MaskField] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included "
|
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||||
"regions should be set to 1.",
|
"included regions should be set to 1.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,6 +20,8 @@ class ExtraConditioningInfo:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BasicConditioningInfo:
|
class BasicConditioningInfo:
|
||||||
|
"""SD 1/2 text conditioning information produced by Compel."""
|
||||||
|
|
||||||
embeds: torch.Tensor
|
embeds: torch.Tensor
|
||||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||||
|
|
||||||
@ -30,6 +32,8 @@ class BasicConditioningInfo:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
|
"""SDXL text conditioning information produced by Compel."""
|
||||||
|
|
||||||
pooled_embeds: torch.Tensor
|
pooled_embeds: torch.Tensor
|
||||||
add_time_ids: torch.Tensor
|
add_time_ids: torch.Tensor
|
||||||
|
|
||||||
|
@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
w //= 2
|
w //= 2
|
||||||
assert h * w == query_seq_len
|
assert h * w == query_seq_len
|
||||||
|
|
||||||
|
# Convert the bool masks to float masks.
|
||||||
|
per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
|
||||||
|
|
||||||
# Apply max-pooling to resize the masks to the target spatial dimensions.
|
# Apply max-pooling to resize the masks to the target spatial dimensions.
|
||||||
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
|
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
|
||||||
# here.
|
# here.
|
||||||
|
@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent:
|
|||||||
def _preprocess_regional_prompt_mask(
|
def _preprocess_regional_prompt_mask(
|
||||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Preprocess a regional prompt mask to match the target height and width.
|
||||||
|
|
||||||
|
If mask is None, returns a mask of all ones with the target height and width.
|
||||||
|
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
||||||
|
"""
|
||||||
if mask is None:
|
if mask is None:
|
||||||
# HACK(ryand): Figure out how to know the target device/dtype.
|
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
||||||
return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda")
|
|
||||||
else:
|
tf = torchvision.transforms.Resize(
|
||||||
# HACK(ryand): It would make more sense to do NEAREST resizing with an integer dtype, and probably on the
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||||
# CPU.
|
)
|
||||||
tf = torchvision.transforms.Resize(
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||||
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
mask = tf(mask)
|
||||||
)
|
|
||||||
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
|
|
||||||
mask = tf(mask)
|
|
||||||
|
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent:
|
|||||||
target_height: int,
|
target_height: int,
|
||||||
target_width: int,
|
target_width: int,
|
||||||
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
||||||
|
"""Prepare text embeddings and associated masks for use in the UNet forward pass.
|
||||||
|
|
||||||
|
- Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or
|
||||||
|
SDXLConditioningInfo).
|
||||||
|
- Preprocesses the masks to match the target height and width, and stacks them into a single tensor.
|
||||||
|
- If all masks are None, skips all mask processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]:
|
||||||
|
(text_embedding, regional_prompt_data)
|
||||||
|
- text_embedding: The concatenated text embeddings.
|
||||||
|
- regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None.
|
||||||
|
"""
|
||||||
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
|
is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo
|
||||||
|
|
||||||
all_masks_are_none = all(mask is None for mask in masks)
|
all_masks_are_none = all(mask is None for mask in masks)
|
||||||
@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
# We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
# We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
|
||||||
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
|
||||||
# the conditioning info, then we shouldn't allow it to be passed in.
|
# the conditioning info, then we shouldn't allow it to be passed in.
|
||||||
|
# How does Compel handle this? Options that come to mind:
|
||||||
|
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
|
||||||
|
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
|
||||||
|
# this is likely the global prompt.
|
||||||
if pooled_embedding is None:
|
if pooled_embedding is None:
|
||||||
pooled_embedding = text_embedding_info.pooled_embeds
|
pooled_embedding = text_embedding_info.pooled_embeds
|
||||||
if add_time_ids is None:
|
if add_time_ids is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user