From 338bf808d64cd99be3090a68f93f158600149741 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 14:16:22 -0400 Subject: [PATCH] Rename MaskField to be a generice TensorField. --- invokeai/app/invocations/compel.py | 6 +++--- invokeai/app/invocations/fields.py | 10 +++++----- invokeai/app/invocations/latent.py | 2 +- invokeai/app/invocations/mask.py | 14 +++++++------- invokeai/app/invocations/primitives.py | 4 ++-- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 6df3301362..92012691ea 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -10,8 +10,8 @@ from invokeai.app.invocations.fields import ( FieldDescriptions, Input, InputField, - MaskField, OutputField, + TensorField, UIComponent, ) from invokeai.app.invocations.primitives import ConditioningOutput @@ -59,7 +59,7 @@ class CompelInvocation(BaseInvocation): description=FieldDescriptions.clip, input=Input.Connection, ) - mask: Optional[MaskField] = InputField( + mask: Optional[TensorField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) @@ -270,7 +270,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): target_height: int = InputField(default=1024, description="") clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") - mask: Optional[MaskField] = InputField( + mask: Optional[TensorField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 56b9e12a6c..0fa0216f1c 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -203,10 +203,10 @@ class DenoiseMaskField(BaseModel): gradient: bool = Field(default=False, description="Used for gradient inpainting") -class MaskField(BaseModel): - """A mask primitive field.""" +class TensorField(BaseModel): + """A tensor primitive field.""" - mask_name: str = Field(description="The name of a spatial mask. dtype: bool, shape: (1, h, w).") + tensor_name: str = Field(description="The name of a tensor.") class LatentsField(BaseModel): @@ -232,9 +232,9 @@ class ConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") - mask: Optional[MaskField] = Field( + mask: Optional[TensorField] = Field( default=None, - description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " + description="The mask associated with this conditioning tensor. Excluded regions should be set to False, " "included regions should be set to True.", ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index db7cd20172..3070cd1e70 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -380,7 +380,7 @@ class DenoiseLatentsInvocation(BaseInvocation): mask = cond.mask if mask is not None: - mask = context.tensors.load(mask.mask_name) + mask = context.tensors.load(mask.tensor_name) text_embeddings_masks.append(mask) return text_embeddings, text_embeddings_masks diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 31eb70e056..de4887e20d 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) -from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata +from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata from invokeai.app.invocations.primitives import MaskOutput @@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata): True ) - mask_name = context.tensors.save(mask) + mask_tensor_name = context.tensors.save(mask) return MaskOutput( - mask=MaskField(mask_name=mask_name), + mask=TensorField(tensor_name=mask_tensor_name), width=self.width, height=self.height, ) @@ -51,7 +51,7 @@ class PromptColorPair(BaseModel): class PromptMaskPair(BaseModel): prompt: str - mask: MaskField + mask: TensorField default_prompt_color_pairs = [ @@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation): mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) # Add explicit channel dimension. mask = mask.unsqueeze(0) - mask_name = context.tensors.save(mask) - prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) + mask_tensor_name = context.tensors.save(mask) + prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name))) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) @invocation_output("split_mask_prompt_pair_output") class SplitMaskPromptPairOutput(BaseInvocationOutput): prompt: str = OutputField() - mask: MaskField = OutputField() + mask: TensorField = OutputField() @invocation( diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 25930f7d00..28f72fb377 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -14,8 +14,8 @@ from invokeai.app.invocations.fields import ( Input, InputField, LatentsField, - MaskField, OutputField, + TensorField, UIComponent, ) from invokeai.app.services.images.images_common import ImageDTO @@ -414,7 +414,7 @@ class ColorInvocation(BaseInvocation): class MaskOutput(BaseInvocationOutput): """A torch mask tensor.""" - mask: MaskField = OutputField(description="The mask.") + mask: TensorField = OutputField(description="The mask.") width: int = OutputField(description="The width of the mask in pixels.") height: int = OutputField(description="The height of the mask in pixels.")