mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename MaskField to be a generice TensorField.
This commit is contained in:
parent
5b5a4204a1
commit
338bf808d6
@ -10,8 +10,8 @@ from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
MaskField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
@ -59,7 +59,7 @@ class CompelInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
)
|
||||
mask: Optional[MaskField] = InputField(
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
@ -270,7 +270,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
target_height: int = InputField(default=1024, description="")
|
||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||
mask: Optional[MaskField] = InputField(
|
||||
mask: Optional[TensorField] = InputField(
|
||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||
)
|
||||
|
||||
|
@ -203,10 +203,10 @@ class DenoiseMaskField(BaseModel):
|
||||
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
||||
|
||||
|
||||
class MaskField(BaseModel):
|
||||
"""A mask primitive field."""
|
||||
class TensorField(BaseModel):
|
||||
"""A tensor primitive field."""
|
||||
|
||||
mask_name: str = Field(description="The name of a spatial mask. dtype: bool, shape: (1, h, w).")
|
||||
tensor_name: str = Field(description="The name of a tensor.")
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
@ -232,9 +232,9 @@ class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
mask: Optional[MaskField] = Field(
|
||||
mask: Optional[TensorField] = Field(
|
||||
default=None,
|
||||
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||
"included regions should be set to True.",
|
||||
)
|
||||
|
||||
|
@ -380,7 +380,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
mask = cond.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.mask_name)
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
text_embeddings_masks.append(mask)
|
||||
|
||||
return text_embeddings, text_embeddings_masks
|
||||
|
@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
|
||||
|
||||
@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
True
|
||||
)
|
||||
|
||||
mask_name = context.tensors.save(mask)
|
||||
mask_tensor_name = context.tensors.save(mask)
|
||||
return MaskOutput(
|
||||
mask=MaskField(mask_name=mask_name),
|
||||
mask=TensorField(tensor_name=mask_tensor_name),
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
||||
@ -51,7 +51,7 @@ class PromptColorPair(BaseModel):
|
||||
|
||||
class PromptMaskPair(BaseModel):
|
||||
prompt: str
|
||||
mask: MaskField
|
||||
mask: TensorField
|
||||
|
||||
|
||||
default_prompt_color_pairs = [
|
||||
@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
|
||||
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
|
||||
# Add explicit channel dimension.
|
||||
mask = mask.unsqueeze(0)
|
||||
mask_name = context.tensors.save(mask)
|
||||
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
|
||||
mask_tensor_name = context.tensors.save(mask)
|
||||
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
|
||||
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
|
||||
|
||||
|
||||
@invocation_output("split_mask_prompt_pair_output")
|
||||
class SplitMaskPromptPairOutput(BaseInvocationOutput):
|
||||
prompt: str = OutputField()
|
||||
mask: MaskField = OutputField()
|
||||
mask: TensorField = OutputField()
|
||||
|
||||
|
||||
@invocation(
|
||||
|
@ -14,8 +14,8 @@ from invokeai.app.invocations.fields import (
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
MaskField,
|
||||
OutputField,
|
||||
TensorField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
@ -414,7 +414,7 @@ class ColorInvocation(BaseInvocation):
|
||||
class MaskOutput(BaseInvocationOutput):
|
||||
"""A torch mask tensor."""
|
||||
|
||||
mask: MaskField = OutputField(description="The mask.")
|
||||
mask: TensorField = OutputField(description="The mask.")
|
||||
width: int = OutputField(description="The width of the mask in pixels.")
|
||||
height: int = OutputField(description="The height of the mask in pixels.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user