rename: Inpaint Mask to Denoise Mask

This commit is contained in:
blessedcoolant
2023-08-27 05:50:13 +12:00
parent 226721ce51
commit c923d094c6
15 changed files with 137 additions and 137 deletions

View File

@ -21,10 +21,10 @@ from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
DenoiseMaskOutput,
ImageField,
ImageOutput,
InpaintMaskField,
InpaintMaskOutput,
LatentsField,
LatentsOutput,
build_latents_output,
@ -57,16 +57,16 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@title("Create Inpaint Mask")
@tags("mask", "inpaint")
class CreateInpaintMaskInvocation(BaseInvocation):
"""Creates mask for inpaint model run."""
@title("Create Denoise Mask")
@tags("mask", "denoise")
class CreateDenoiseMaskInvocation(BaseInvocation):
"""Creates mask for denoising model run."""
# Metadata
type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
type: Literal["create_denoise_mask"] = "create_denoise_mask"
# Inputs
image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked")
mask: ImageField = InputField(description="The mask to use when pasting")
vae: VaeField = InputField(
description=FieldDescriptions.vae,
@ -86,7 +86,7 @@ class CreateInpaintMaskInvocation(BaseInvocation):
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
@ -118,8 +118,8 @@ class CreateInpaintMaskInvocation(BaseInvocation):
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask)
return InpaintMaskOutput(
inpaint_mask=InpaintMaskField(
return DenoiseMaskOutput(
denoise_mask=DenoiseMaskField(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
),
@ -189,7 +189,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
mask: Optional[InpaintMaskField] = InputField(
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None,
description=FieldDescriptions.mask,
)
@ -403,13 +403,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep
def prep_inpaint_mask(self, context, latents):
if self.mask is None:
if self.denoise_mask is None:
return None, None
mask = context.services.latents.get(self.mask.mask_name)
mask = context.services.latents.get(self.denoise_mask.mask_name)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
if self.mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.mask.masked_latents_name)
if self.denoise_mask.masked_latents_name is not None:
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
else:
masked_latents = None

View File

@ -296,21 +296,21 @@ class ImageCollectionInvocation(BaseInvocation):
# endregion
# region InpaintMask
# region DenoiseMask
class InpaintMaskField(BaseModel):
class DenoiseMaskField(BaseModel):
"""An inpaint mask field"""
mask_name: str = Field(description="The name of the mask image")
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
class InpaintMaskOutput(BaseInvocationOutput):
class DenoiseMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run")
type: Literal["denoise_mask_output"] = "denoise_mask_output"
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
# endregion