mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename MaskField to be a generice TensorField.
This commit is contained in:
parent
5b5a4204a1
commit
338bf808d6
@ -10,8 +10,8 @@ from invokeai.app.invocations.fields import (
|
|||||||
FieldDescriptions,
|
FieldDescriptions,
|
||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
MaskField,
|
|
||||||
OutputField,
|
OutputField,
|
||||||
|
TensorField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||||
@ -59,7 +59,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
description=FieldDescriptions.clip,
|
description=FieldDescriptions.clip,
|
||||||
input=Input.Connection,
|
input=Input.Connection,
|
||||||
)
|
)
|
||||||
mask: Optional[MaskField] = InputField(
|
mask: Optional[TensorField] = InputField(
|
||||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -270,7 +270,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
|||||||
target_height: int = InputField(default=1024, description="")
|
target_height: int = InputField(default=1024, description="")
|
||||||
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
|
||||||
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
|
||||||
mask: Optional[MaskField] = InputField(
|
mask: Optional[TensorField] = InputField(
|
||||||
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
default=None, description="A mask defining the region that this conditioning prompt applies to."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -203,10 +203,10 @@ class DenoiseMaskField(BaseModel):
|
|||||||
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
||||||
|
|
||||||
|
|
||||||
class MaskField(BaseModel):
|
class TensorField(BaseModel):
|
||||||
"""A mask primitive field."""
|
"""A tensor primitive field."""
|
||||||
|
|
||||||
mask_name: str = Field(description="The name of a spatial mask. dtype: bool, shape: (1, h, w).")
|
tensor_name: str = Field(description="The name of a tensor.")
|
||||||
|
|
||||||
|
|
||||||
class LatentsField(BaseModel):
|
class LatentsField(BaseModel):
|
||||||
@ -232,9 +232,9 @@ class ConditioningField(BaseModel):
|
|||||||
"""A conditioning tensor primitive value"""
|
"""A conditioning tensor primitive value"""
|
||||||
|
|
||||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||||
mask: Optional[MaskField] = Field(
|
mask: Optional[TensorField] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="The bool mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
|
||||||
"included regions should be set to True.",
|
"included regions should be set to True.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -380,7 +380,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
mask = cond.mask
|
mask = cond.mask
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask = context.tensors.load(mask.mask_name)
|
mask = context.tensors.load(mask.tensor_name)
|
||||||
text_embeddings_masks.append(mask)
|
text_embeddings_masks.append(mask)
|
||||||
|
|
||||||
return text_embeddings, text_embeddings_masks
|
return text_embeddings, text_embeddings_masks
|
||||||
|
@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
|
|||||||
invocation,
|
invocation,
|
||||||
invocation_output,
|
invocation_output,
|
||||||
)
|
)
|
||||||
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata
|
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata
|
||||||
from invokeai.app.invocations.primitives import MaskOutput
|
from invokeai.app.invocations.primitives import MaskOutput
|
||||||
|
|
||||||
|
|
||||||
@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
True
|
True
|
||||||
)
|
)
|
||||||
|
|
||||||
mask_name = context.tensors.save(mask)
|
mask_tensor_name = context.tensors.save(mask)
|
||||||
return MaskOutput(
|
return MaskOutput(
|
||||||
mask=MaskField(mask_name=mask_name),
|
mask=TensorField(tensor_name=mask_tensor_name),
|
||||||
width=self.width,
|
width=self.width,
|
||||||
height=self.height,
|
height=self.height,
|
||||||
)
|
)
|
||||||
@ -51,7 +51,7 @@ class PromptColorPair(BaseModel):
|
|||||||
|
|
||||||
class PromptMaskPair(BaseModel):
|
class PromptMaskPair(BaseModel):
|
||||||
prompt: str
|
prompt: str
|
||||||
mask: MaskField
|
mask: TensorField
|
||||||
|
|
||||||
|
|
||||||
default_prompt_color_pairs = [
|
default_prompt_color_pairs = [
|
||||||
@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
|
|||||||
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
|
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
|
||||||
# Add explicit channel dimension.
|
# Add explicit channel dimension.
|
||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
mask_name = context.tensors.save(mask)
|
mask_tensor_name = context.tensors.save(mask)
|
||||||
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
|
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
|
||||||
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
|
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("split_mask_prompt_pair_output")
|
@invocation_output("split_mask_prompt_pair_output")
|
||||||
class SplitMaskPromptPairOutput(BaseInvocationOutput):
|
class SplitMaskPromptPairOutput(BaseInvocationOutput):
|
||||||
prompt: str = OutputField()
|
prompt: str = OutputField()
|
||||||
mask: MaskField = OutputField()
|
mask: TensorField = OutputField()
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -14,8 +14,8 @@ from invokeai.app.invocations.fields import (
|
|||||||
Input,
|
Input,
|
||||||
InputField,
|
InputField,
|
||||||
LatentsField,
|
LatentsField,
|
||||||
MaskField,
|
|
||||||
OutputField,
|
OutputField,
|
||||||
|
TensorField,
|
||||||
UIComponent,
|
UIComponent,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.images.images_common import ImageDTO
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
@ -414,7 +414,7 @@ class ColorInvocation(BaseInvocation):
|
|||||||
class MaskOutput(BaseInvocationOutput):
|
class MaskOutput(BaseInvocationOutput):
|
||||||
"""A torch mask tensor."""
|
"""A torch mask tensor."""
|
||||||
|
|
||||||
mask: MaskField = OutputField(description="The mask.")
|
mask: TensorField = OutputField(description="The mask.")
|
||||||
width: int = OutputField(description="The width of the mask in pixels.")
|
width: int = OutputField(description="The width of the mask in pixels.")
|
||||||
height: int = OutputField(description="The height of the mask in pixels.")
|
height: int = OutputField(description="The height of the mask in pixels.")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user