mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
rename: Inpaint Mask to Denoise Mask
This commit is contained in:
@ -21,10 +21,10 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
DenoiseMaskOutput,
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
InpaintMaskField,
|
||||
InpaintMaskOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
@ -57,16 +57,16 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
|
||||
|
||||
@title("Create Inpaint Mask")
|
||||
@tags("mask", "inpaint")
|
||||
class CreateInpaintMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for inpaint model run."""
|
||||
@title("Create Denoise Mask")
|
||||
@tags("mask", "denoise")
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
# Metadata
|
||||
type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
|
||||
type: Literal["create_denoise_mask"] = "create_denoise_mask"
|
||||
|
||||
# Inputs
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
|
||||
image: Optional[ImageField] = InputField(default=None, description="Image which will be masked")
|
||||
mask: ImageField = InputField(description="The mask to use when pasting")
|
||||
vae: VaeField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
@ -86,7 +86,7 @@ class CreateInpaintMaskInvocation(BaseInvocation):
|
||||
return mask_tensor
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
@ -118,8 +118,8 @@ class CreateInpaintMaskInvocation(BaseInvocation):
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
|
||||
return InpaintMaskOutput(
|
||||
inpaint_mask=InpaintMaskField(
|
||||
return DenoiseMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
),
|
||||
@ -189,7 +189,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
default=None, description=FieldDescriptions.control, input=Input.Connection, ui_order=5
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
mask: Optional[InpaintMaskField] = InputField(
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.mask,
|
||||
)
|
||||
@ -403,13 +403,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
return num_inference_steps, timesteps, init_timestep
|
||||
|
||||
def prep_inpaint_mask(self, context, latents):
|
||||
if self.mask is None:
|
||||
if self.denoise_mask is None:
|
||||
return None, None
|
||||
|
||||
mask = context.services.latents.get(self.mask.mask_name)
|
||||
mask = context.services.latents.get(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
|
||||
if self.mask.masked_latents_name is not None:
|
||||
masked_latents = context.services.latents.get(self.mask.masked_latents_name)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
|
@ -296,21 +296,21 @@ class ImageCollectionInvocation(BaseInvocation):
|
||||
|
||||
# endregion
|
||||
|
||||
# region InpaintMask
|
||||
# region DenoiseMask
|
||||
|
||||
|
||||
class InpaintMaskField(BaseModel):
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
|
||||
|
||||
|
||||
class InpaintMaskOutput(BaseInvocationOutput):
|
||||
class DenoiseMaskOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
|
||||
inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run")
|
||||
type: Literal["denoise_mask_output"] = "denoise_mask_output"
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
|
||||
|
||||
# endregion
|
||||
|
Reference in New Issue
Block a user