InvokeAI/invokeai/app/invocations/mask.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

37 lines
1.5 KiB
Python
Raw Normal View History

2024-03-08 15:30:55 +00:00
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext, invocation
from invokeai.app.invocations.fields import InputField, TensorField, WithMetadata
2024-03-08 15:30:55 +00:00
from invokeai.app.invocations.primitives import MaskOutput
@invocation(
    "rectangle_mask",
    title="Create Rectangle Mask",
    tags=["conditioning"],
    category="conditioning",
    version="1.0.1",
)
class RectangleMaskInvocation(BaseInvocation, WithMetadata):
    """Create a rectangular mask."""

    # Overall mask dimensions, in pixels.
    width: int = InputField(description="The width of the entire mask.")
    height: int = InputField(description="The height of the entire mask.")

    # Top-left corner of the masked rectangle. May lie outside the mask; the rectangle is clipped.
    x_left: int = InputField(description="The left x-coordinate of the rectangular masked region (inclusive).")
    y_top: int = InputField(description="The top y-coordinate of the rectangular masked region (inclusive).")

    # Size of the masked rectangle.
    rectangle_width: int = InputField(description="The width of the rectangular masked region.")
    rectangle_height: int = InputField(description="The height of the rectangular masked region.")

    def invoke(self, context: InvocationContext) -> MaskOutput:
        """Build the boolean mask, save it via the context, and return a reference to it.

        The mask is a bool tensor of shape (1, height, width). It is True inside the rectangle
        [y_top, y_top + rectangle_height) x [x_left, x_left + rectangle_width), clipped to the
        mask bounds, and False elsewhere.
        """
        mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)

        # Clamp the rectangle's bounds at 0. Python/torch slicing interprets negative indices as
        # offsets from the end, so an unclamped negative x_left/y_top would wrap around (or produce
        # an empty slice) rather than clipping the rectangle at the left/top edge. Bounds past the
        # right/bottom edge need no clamping: slicing already truncates them to the tensor size.
        x_start = max(self.x_left, 0)
        y_start = max(self.y_top, 0)
        x_end = max(self.x_left + self.rectangle_width, 0)
        y_end = max(self.y_top + self.rectangle_height, 0)
        mask[:, y_start:y_end, x_start:x_end] = True

        mask_tensor_name = context.tensors.save(mask)
        return MaskOutput(
            mask=TensorField(tensor_name=mask_tensor_name),
            width=self.width,
            height=self.height,
        )