Rename MaskField to be a generice TensorField.

This commit is contained in:
Ryan Dick 2024-04-08 14:16:22 -04:00 committed by Kent Keirsey
parent 5b5a4204a1
commit 338bf808d6
5 changed files with 18 additions and 18 deletions

View File

@ -10,8 +10,8 @@ from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
Input, Input,
InputField, InputField,
MaskField,
OutputField, OutputField,
TensorField,
UIComponent, UIComponent,
) )
from invokeai.app.invocations.primitives import ConditioningOutput from invokeai.app.invocations.primitives import ConditioningOutput
@ -59,7 +59,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip, description=FieldDescriptions.clip,
input=Input.Connection, input=Input.Connection,
) )
mask: Optional[MaskField] = InputField( mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to." default=None, description="A mask defining the region that this conditioning prompt applies to."
) )
@ -270,7 +270,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
target_height: int = InputField(default=1024, description="") target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1") clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2") clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField( mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to." default=None, description="A mask defining the region that this conditioning prompt applies to."
) )

View File

@ -203,10 +203,10 @@ class DenoiseMaskField(BaseModel):
gradient: bool = Field(default=False, description="Used for gradient inpainting") gradient: bool = Field(default=False, description="Used for gradient inpainting")
class MaskField(BaseModel): class TensorField(BaseModel):
"""A mask primitive field.""" """A tensor primitive field."""
mask_name: str = Field(description="The name of a spatial mask. dtype: bool, shape: (1, h, w).") tensor_name: str = Field(description="The name of a tensor.")
class LatentsField(BaseModel): class LatentsField(BaseModel):
@ -232,9 +232,9 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value""" """A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor") conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[MaskField] = Field( mask: Optional[TensorField] = Field(
default=None, default=None,
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, " description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.", "included regions should be set to True.",
) )

View File

@ -380,7 +380,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask = cond.mask mask = cond.mask
if mask is not None: if mask is not None:
mask = context.tensors.load(mask.mask_name) mask = context.tensors.load(mask.tensor_name)
text_embeddings_masks.append(mask) text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks return text_embeddings, text_embeddings_masks

View File

@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation, invocation,
invocation_output, invocation_output,
) )
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput from invokeai.app.invocations.primitives import MaskOutput
@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
True True
) )
mask_name = context.tensors.save(mask) mask_tensor_name = context.tensors.save(mask)
return MaskOutput( return MaskOutput(
mask=MaskField(mask_name=mask_name), mask=TensorField(tensor_name=mask_tensor_name),
width=self.width, width=self.width,
height=self.height, height=self.height,
) )
@ -51,7 +51,7 @@ class PromptColorPair(BaseModel):
class PromptMaskPair(BaseModel): class PromptMaskPair(BaseModel):
prompt: str prompt: str
mask: MaskField mask: TensorField
default_prompt_color_pairs = [ default_prompt_color_pairs = [
@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1) mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
# Add explicit channel dimension. # Add explicit channel dimension.
mask = mask.unsqueeze(0) mask = mask.unsqueeze(0)
mask_name = context.tensors.save(mask) mask_tensor_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name))) prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs) return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
@invocation_output("split_mask_prompt_pair_output") @invocation_output("split_mask_prompt_pair_output")
class SplitMaskPromptPairOutput(BaseInvocationOutput): class SplitMaskPromptPairOutput(BaseInvocationOutput):
prompt: str = OutputField() prompt: str = OutputField()
mask: MaskField = OutputField() mask: TensorField = OutputField()
@invocation( @invocation(

View File

@ -14,8 +14,8 @@ from invokeai.app.invocations.fields import (
Input, Input,
InputField, InputField,
LatentsField, LatentsField,
MaskField,
OutputField, OutputField,
TensorField,
UIComponent, UIComponent,
) )
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
@ -414,7 +414,7 @@ class ColorInvocation(BaseInvocation):
class MaskOutput(BaseInvocationOutput): class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor.""" """A torch mask tensor."""
mask: MaskField = OutputField(description="The mask.") mask: TensorField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.") width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.") height: int = OutputField(description="The height of the mask in pixels.")