mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rename MaskField to be a generice TensorField.
This commit is contained in:
@ -9,7 +9,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, MaskField, OutputField, WithMetadata
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField, InputField, OutputField, TensorField, WithMetadata
|
||||
from invokeai.app.invocations.primitives import MaskOutput
|
||||
|
||||
|
||||
@ -36,9 +36,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
True
|
||||
)
|
||||
|
||||
mask_name = context.tensors.save(mask)
|
||||
mask_tensor_name = context.tensors.save(mask)
|
||||
return MaskOutput(
|
||||
mask=MaskField(mask_name=mask_name),
|
||||
mask=TensorField(tensor_name=mask_tensor_name),
|
||||
width=self.width,
|
||||
height=self.height,
|
||||
)
|
||||
@ -51,7 +51,7 @@ class PromptColorPair(BaseModel):
|
||||
|
||||
class PromptMaskPair(BaseModel):
|
||||
prompt: str
|
||||
mask: MaskField
|
||||
mask: TensorField
|
||||
|
||||
|
||||
default_prompt_color_pairs = [
|
||||
@ -92,15 +92,15 @@ class ExtractMasksAndPromptsInvocation(BaseInvocation):
|
||||
mask = torch.all(image_as_tensor == torch.tensor(pair.color.tuple()), dim=-1)
|
||||
# Add explicit channel dimension.
|
||||
mask = mask.unsqueeze(0)
|
||||
mask_name = context.tensors.save(mask)
|
||||
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=MaskField(mask_name=mask_name)))
|
||||
mask_tensor_name = context.tensors.save(mask)
|
||||
prompt_mask_pairs.append(PromptMaskPair(prompt=pair.prompt, mask=TensorField(tensor_name=mask_tensor_name)))
|
||||
return ExtractMasksAndPromptsOutput(prompt_mask_pairs=prompt_mask_pairs)
|
||||
|
||||
|
||||
@invocation_output("split_mask_prompt_pair_output")
|
||||
class SplitMaskPromptPairOutput(BaseInvocationOutput):
|
||||
prompt: str = OutputField()
|
||||
mask: MaskField = OutputField()
|
||||
mask: TensorField = OutputField()
|
||||
|
||||
|
||||
@invocation(
|
||||
|
Reference in New Issue
Block a user