Rename MaskField to be a generice TensorField.

This commit is contained in:
Ryan Dick
2024-04-08 14:16:22 -04:00
committed by Kent Keirsey
parent 5b5a4204a1
commit 338bf808d6
5 changed files with 18 additions and 18 deletions

View File

@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata
from invokeai.app.invocations.primitives import MaskOutput
@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
True
)
mask_name = context.tensors.save(mask)
mask_tensor_name = context.tensors.save(mask)
return MaskOutput(
mask=MaskField(mask_name=mask_name),
mask=TensorField(tensor_name=mask_tensor_name),
width=self.width,
height=self.height,
)
@ -51,7 +51,7 @@ class PromptColorPair(BaseModel):
class PromptMaskPair(BaseModel):
prompt: str
mask: MaskField
mask: TensorField
default_prompt_color_pairs = [
@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
# Add explicit channel dimension.
mask = mask.unsqueeze(0)
mask_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
mask_tensor_name = context.tensors.save(mask)
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
@invocation_output("split_mask_prompt_pair_output")
class SplitMaskPromptPairOutput(BaseInvocationOutput):
prompt: str = OutputField()
mask: MaskField = OutputField()
mask: TensorField = OutputField()
@invocation(