From b0fcbe552ee53ae6084ab08fcd0c3394c6e836d1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 26 Feb 2024 17:34:37 -0500 Subject: [PATCH] Tidy invocation interfaces for RectangleMaskInvocation and AddConditioningMaskInvocation. --- invokeai/app/invocations/conditioning.py | 71 ++++++------------- invokeai/app/invocations/latent.py | 7 +- invokeai/app/invocations/primitives.py | 24 ++++++- .../diffusion/conditioning_data.py | 4 ++ .../diffusion/regional_prompt_attention.py | 3 + .../diffusion/shared_invokeai_diffusion.py | 42 ++++++++--- 6 files changed, 85 insertions(+), 66 deletions(-) diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py index 323ae6c038..3102951408 100644 --- a/invokeai/app/invocations/conditioning.py +++ b/invokeai/app/invocations/conditioning.py @@ -1,6 +1,4 @@ -import numpy as np import torch -from PIL import Image from invokeai.app.invocations.baseinvocation import ( BaseInvocation, @@ -9,8 +7,7 @@ from invokeai.app.invocations.baseinvocation import ( WithMetadata, invocation, ) -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin +from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, MaskField, MaskOutput @invocation( @@ -24,27 +21,10 @@ class AddConditioningMaskInvocation(BaseInvocation): """Add a mask to an existing conditioning tensor.""" conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") - image: ImageField = InputField( - description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. " - "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask." - ) - - @staticmethod - def convert_image_to_mask(image: Image.Image) -> torch.Tensor: - """Convert a PIL image to a uint8 mask tensor.""" - np_image = np.array(image) - torch_image = torch.from_numpy(np_image[:, :, 0]) - mask = torch_image >= 128 - return mask.to(dtype=torch.uint8) + mask: MaskField = InputField(description="A mask to add to the conditioning tensor.") def invoke(self, context: InvocationContext) -> ConditioningOutput: - image = context.services.images.get_pil_image(self.image.image_name) - mask = self.convert_image_to_mask(image) - - mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask" - context.services.latents.save(mask_name, mask) - - self.conditioning.mask_name = mask_name + self.conditioning.mask = self.mask return ConditioningOutput(conditioning=self.conditioning) @@ -56,33 +36,26 @@ class AddConditioningMaskInvocation(BaseInvocation): version="1.0.0", ) class RectangleMaskInvocation(BaseInvocation, WithMetadata): - """Create a mask image containing a rectangular mask region.""" + """Create a rectangular mask.""" - height: int = InputField(description="The height of the image.") - width: int = InputField(description="The width of the image.") - y_top: int = InputField(description="The top y-coordinate of the rectangle (inclusive).") - y_bottom: int = InputField(description="The bottom y-coordinate of the rectangle (exclusive).") - x_left: int = InputField(description="The left x-coordinate of the rectangle (inclusive).") - x_right: int = InputField(description="The right x-coordinate of the rectangle (exclusive).") + height: int = InputField(description="The height of the entire mask.") + width: int = InputField(description="The width of the entire mask.") + y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).") + x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).") + rectangle_height: int = InputField(description="The height of the rectangular masked region.") + rectangle_width: int = InputField(description="The width of the rectangular masked region.") - def invoke(self, context: InvocationContext) -> ImageOutput: - mask = np.zeros((self.height, self.width, 3), dtype=np.uint8) - mask[self.y_top : self.y_bottom, self.x_left : self.x_right, :] = 255 - mask_image = Image.fromarray(mask) + def invoke(self, context: InvocationContext) -> MaskOutput: + mask = torch.zeros((1, self.height, self.width), dtype=torch.bool) + mask[ + :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width + ] = True - image_dto = context.services.images.create( - image=mask_image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, + mask_name = f"{context.graph_execution_state_id}__{self.id}_mask" + context.services.latents.save(mask_name, mask) + + return MaskOutput( + mask=MaskField(mask_name=mask_name), + width=self.width, + height=self.height, ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index dbebc2ab82..7716457843 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -345,10 +345,9 @@ class DenoiseLatentsInvocation(BaseInvocation): positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name) text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)) - mask_name = positive_conditioning.mask_name - mask = None - if mask_name is not None: - mask = context.services.latents.get(mask_name) + mask = positive_conditioning.mask + if mask is not None: + mask = context.services.latents.get(mask.mask_name) text_embeddings_masks.append(mask) negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index bd07f3010b..6ec77dd0ec 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -233,6 +233,24 @@ class BoardField(BaseModel): board_id: str = Field(description="The id of the board") +class MaskField(BaseModel): + """A mask primitive field.""" + + mask_name: str = Field(description="The name of the mask.") + + +@invocation_output("mask_output") +class MaskOutput(BaseInvocationOutput): + """A torch mask tensor. + dtype: torch.bool + shape: (1, height, width). + """ + + mask: MaskField = OutputField(description="The mask.") + width: int = OutputField(description="The width of the mask in pixels.") + height: int = OutputField(description="The height of the mask in pixels.") + + @invocation_output("image_output") class ImageOutput(BaseInvocationOutput): """Base class for nodes that output a single image""" @@ -428,10 +446,10 @@ class ConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") - mask_name: Optional[str] = Field( + mask: Optional[MaskField] = Field( default=None, - description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included " - "regions should be set to 1.", + description="The mask associated with this conditioning tensor. Excluded regions should be set to False, " + "included regions should be set to 1.", ) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index fb104d7bab..ca215a5714 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -20,6 +20,8 @@ class ExtraConditioningInfo: @dataclass class BasicConditioningInfo: + """SD 1/2 text conditioning information produced by Compel.""" + embeds: torch.Tensor extra_conditioning: Optional[ExtraConditioningInfo] @@ -30,6 +32,8 @@ class BasicConditioningInfo: @dataclass class SDXLConditioningInfo(BasicConditioningInfo): + """SDXL text conditioning information produced by Compel.""" + pooled_embeds: torch.Tensor add_time_ids: torch.Tensor diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py index 861ac128ea..2db4a24570 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_attention.py @@ -52,6 +52,9 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0): w //= 2 assert h * w == query_seq_len + # Convert the bool masks to float masks. + per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32) + # Apply max-pooling to resize the masks to the target spatial dimensions. # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation # here. diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index 7dbe3586e7..a8e663b49d 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -313,17 +313,22 @@ class InvokeAIDiffuserComponent: def _preprocess_regional_prompt_mask( self, mask: Optional[torch.Tensor], target_height: int, target_width: int ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation. + + Returns: + torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width). + """ if mask is None: - # HACK(ryand): Figure out how to know the target device/dtype. - return torch.ones((1, 1, target_height, target_width), dtype=torch.float16, device="cuda") - else: - # HACK(ryand): It would make more sense to do NEAREST resising with an integer dtype, and probably on the - # CPU. - tf = torchvision.transforms.Resize( - (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST - ) - mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w) - mask = tf(mask) + return torch.ones((1, 1, target_height, target_width), dtype=torch.bool) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + mask = tf(mask) return mask @@ -334,6 +339,19 @@ class InvokeAIDiffuserComponent: target_height: int, target_width: int, ) -> Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]: + """Prepare text embeddings and associated masks for use in the UNet forward pass. + + - Concatenates the text embeddings into a single tensor (returned as a single BasicConditioningInfo or + SDXLConditioningInfo). + - Preprocesses the masks to match the target height and width, and stacks them into a single tensor. + - If all masks are None, skips all mask processing. + + Returns: + Tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[RegionalPromptData]]: + (text_embedding, regional_prompt_data) + - text_embedding: The concatenated text embeddings. + - regional_prompt_data: The processed masks and embedding ranges, or None if all masks are None. + """ is_sdxl = type(text_embeddings[0]) is SDXLConditioningInfo all_masks_are_none = all(mask is None for mask in masks) @@ -356,6 +374,10 @@ class InvokeAIDiffuserComponent: # We just use the the first SDXLConditioningInfo's pooled_embeds and add_time_ids. # TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all # the conditioning info, then we shouldn't allow it to be passed in. + # How does Compel handle this? Options that come to mind: + # - Blend the pooled_embeds and add_time_ids from all of the text embeddings. + # - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since + # this is likely the global prompt. if pooled_embedding is None: pooled_embedding = text_embedding_info.pooled_embeds if add_time_ids is None: