From 58277c6adad53ef64e653125fcbdb9799127b7de Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 13 Feb 2024 14:24:46 -0500 Subject: [PATCH] Add a mask to the ConditioningField primitive type. --- invokeai/app/invocations/conditioning.py | 41 ++++++++++++++++++++++++ invokeai/app/invocations/primitives.py | 5 +++ 2 files changed, 46 insertions(+) create mode 100644 invokeai/app/invocations/conditioning.py diff --git a/invokeai/app/invocations/conditioning.py b/invokeai/app/invocations/conditioning.py new file mode 100644 index 0000000000..9579d80009 --- /dev/null +++ b/invokeai/app/invocations/conditioning.py @@ -0,0 +1,41 @@ +import numpy as np +import torch +from PIL.Image import Image + +from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, InvocationContext, invocation +from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField + + +@invocation( + "add_conditioning_mask", + title="Add Conditioning Mask", + tags=["conditioning"], + category="conditioning", + version="1.0.0", +) +class AddConditioningMaskInvocation(BaseInvocation): + """Add a mask to an existing conditioning tensor.""" + + conditioning: ConditioningField = InputField(description="The conditioning tensor to add a mask to.") + image: ImageField = InputField( + description="A mask image to add to the conditioning tensor. Only the first channel of the image is used. " + "Pixels <128 are excluded from the mask, pixels >=128 are included in the mask." + ) + + @staticmethod + def convert_image_to_mask(image: Image) -> torch.Tensor: + """Convert a PIL image to a uint8 mask tensor.""" + np_image = np.array(image) + torch_image = torch.from_numpy(np_image[0, :, :]) + mask = torch_image >= 128 + return mask.to(dtype=torch.uint8) + + def invoke(self, context: InvocationContext) -> ConditioningOutput: + image = context.services.images.get_pil_image(self.image.image_name) + mask = self.convert_image_to_mask(image) + + mask_name = f"{context.graph_execution_state_id}__{self.id}_conditioning_mask" + context.services.latents.save(mask_name, mask) + + self.conditioning.mask_name = mask_name + return ConditioningOutput(conditioning=self.conditioning) diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index afe8ff06d9..bd07f3010b 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -428,6 +428,11 @@ class ConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") + mask_name: Optional[str] = Field( + default=None, + description="The mask associated with this conditioning tensor. Excluded regions should be set to 0, included " + "regions should be set to 1.", + ) @invocation_output("conditioning_output")