Rename MaskField to be a generice TensorField.

This commit is contained in:
Ryan Dick 2024-04-08 14:16:22 -04:00 committed by Kent Keirsey
parent 5b5a4204a1
commit 338bf808d6
5 changed files with 18 additions and 18 deletions

View File

@ -10,8 +10,8 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
MaskField,
OutputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
@ -59,7 +59,7 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
mask: Optional[MaskField] = InputField(
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@ -270,7 +270,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[MaskField] = InputField(
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)

View File

@ -203,10 +203,10 @@ class DenoiseMaskField(BaseModel):
gradient: bool = Field(default=False, description="Used for gradient inpainting")
class MaskField(BaseModel):
"""A mask primitive field."""
class TensorField(BaseModel):
"""A tensor primitive field."""
mask_name: str = Field(description="The name of a spatial mask. dtype: bool, shape: (1, h, w).")
tensor_name: str = Field(description="The name of a tensor.")
class LatentsField(BaseModel):
@ -232,9 +232,9 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
mask: Optional[MaskField] = Field(
mask: Optional[TensorField] = Field(
default=None,
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)

View File

@ -380,7 +380,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.mask_name)
mask = context.tensors.load(mask.tensor_name)
text_embeddings_masks.append(mask)
return text_embeddings, text_embeddings_masks

View File

@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
True
)
mask_name = context.tensors.save(mask)
mask_tensor_name = context.tensors.save(mask)
return MaskOutput(
mask=MaskField(mask_name=mask_name),
mask=TensorField(tensor_name=mask_tensor_name),
width=self.width,
height=self.height,
)
@ -51,7 +51,7 @@ class PromptColorPair(BaseModel):
class PromptMaskPair(BaseModel):
prompt: str
mask: MaskField
mask: TensorField
default_prompt_color_pairs = [
@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
# Add explicit channel dimension.
mask = mask.unsqueeze(0)
mask_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
mask_tensor_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
@invocation_output("split_mask_prompt_pair_output")
class SplitMaskPromptPairOutput(BaseInvocationOutput):
prompt: str = OutputField()
mask: MaskField = OutputField()
mask: TensorField = OutputField()
@invocation(

View File

@ -14,8 +14,8 @@ from invokeai.app.invocations.fields import (
Input,
InputField,
LatentsField,
MaskField,
OutputField,
TensorField,
UIComponent,
)
from invokeai.app.services.images.images_common import ImageDTO
@ -414,7 +414,7 @@ class ColorInvocation(BaseInvocation):
class MaskOutput(BaseInvocationOutput):
"""A torch mask tensor."""
mask: MaskField = OutputField(description="The mask.")
mask: TensorField = OutputField(description="The mask.")
width: int = OutputField(description="The width of the mask in pixels.")
height: int = OutputField(description="The height of the mask in pixels.")