Added node for creating mask inpaint

This commit is contained in:
Sergey Borisov 2023-08-18 04:07:40 +03:00
parent e9a294f733
commit cfd827cfad
3 changed files with 122 additions and 37 deletions

View File

@ -25,6 +25,7 @@ from invokeai.app.invocations.primitives import (
LatentsField, LatentsField,
LatentsOutput, LatentsOutput,
InpaintMaskField, InpaintMaskField,
InpaintMaskOutput,
build_latents_output, build_latents_output,
) )
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
@ -66,6 +67,76 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@title("Create inpaint mask")
@tags("mask", "inpaint")
class CreateInpaintMaskInvocation(BaseInvocation):
"""Creates mask for inpaint model run."""
# Metadata
type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
# Inputs
image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
mask: ImageField = InputField(description="The mask to use when pasting")
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
def prep_mask_tensor(self, mask_image):
if mask_image.mode != "L":
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
if mask_tensor.dim() == 3:
mask_tensor = mask_tensor.unsqueeze(0)
#if shape is not None:
# mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
if self.image is not None:
image = context.services.images.get_pil_image(self.image.image_name)
image = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image.dim() == 3:
image = image.unsqueeze(0)
else:
image = None
mask = self.prep_mask_tensor(
context.services.images.get_pil_image(self.mask.image_name),
)
if image is not None:
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR)
masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
context.services.latents.save(masked_latents_name, masked_latents)
else:
masked_latents_name = None
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
context.services.latents.save(mask_name, mask)
return InpaintMaskOutput(
inpaint_mask=InpaintMaskField(
mask_name=mask_name,
masked_latents_name=masked_latents_name,
),
)
def get_scheduler( def get_scheduler(
context: InvocationContext, context: InvocationContext,
scheduler_info: ModelInfo, scheduler_info: ModelInfo,
@ -340,19 +411,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
return num_inference_steps, timesteps, init_timestep return num_inference_steps, timesteps, init_timestep
def prep_mask_tensor(self, mask, context, lantents): def prep_inpaint_mask(self, context, latents):
if mask is None: if self.mask is None:
return None return None, None
mask_image = context.services.images.get_pil_image(mask.image_name) mask = context.services.latents.get(self.mask.mask_name)
if mask_image.mode != "L": mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
# FIXME: why do we get passed an RGB image here? We can only use single-channel. if self.mask.masked_latents_name is not None:
mask_image = mask_image.convert("L") masked_latents = context.services.latents.get(self.mask.masked_latents_name)
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False) else:
if mask_tensor.dim() == 3: masked_latents = None
mask_tensor = mask_tensor.unsqueeze(0)
mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR) return 1 - mask, masked_latents
return 1 - mask_tensor
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput: def invoke(self, context: InvocationContext) -> LatentsOutput:
@ -373,7 +443,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if seed is None: if seed is None:
seed = 0 seed = 0
mask = self.prep_mask_tensor(self.mask, context, latents) mask, masked_latents = self.prep_inpaint_mask(context, latents)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@ -404,6 +474,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
if mask is not None: if mask is not None:
mask = mask.to(device=unet.device, dtype=unet.dtype) mask = mask.to(device=unet.device, dtype=unet.dtype)
if masked_latents is not None:
masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler( scheduler = get_scheduler(
context=context, context=context,
@ -440,6 +512,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
noise=noise, noise=noise,
seed=seed, seed=seed,
mask=mask, mask=mask,
masked_latents=masked_latents,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData] control_data=control_data, # list[ControlNetData]
@ -661,26 +734,11 @@ class ImageToLatentsInvocation(BaseInvocation):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
@torch.no_grad() @staticmethod
def invoke(self, context: InvocationContext) -> LatentsOutput: def vae_encode(vae_info, upcast, tiled, image_tensor):
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(self.image.image_name)
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
with vae_info as vae: with vae_info as vae:
orig_dtype = vae.dtype orig_dtype = vae.dtype
if self.fp32: if upcast:
vae.to(dtype=torch.float32) vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance( use_torch_2_0_or_xformers = isinstance(
@ -705,7 +763,7 @@ class ImageToLatentsInvocation(BaseInvocation):
vae.to(dtype=torch.float16) vae.to(dtype=torch.float16)
# latents = latents.half() # latents = latents.half()
if self.tiled: if tiled:
vae.enable_tiling() vae.enable_tiling()
else: else:
vae.disable_tiling() vae.disable_tiling()
@ -719,6 +777,27 @@ class ImageToLatentsInvocation(BaseInvocation):
latents = vae.config.scaling_factor * latents latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype) latents = latents.to(dtype=orig_dtype)
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(self.image.image_name)
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
name = f"{context.graph_execution_state_id}__{self.id}" name = f"{context.graph_execution_state_id}__{self.id}"
latents = latents.to("cpu") latents = latents.to("cpu")
context.services.latents.save(name, latents) context.services.latents.save(name, latents)

View File

@ -314,7 +314,14 @@ class InpaintMaskField(BaseModel):
"""An inpaint mask field""" """An inpaint mask field"""
mask_name: str = Field(description="The name of the mask image") mask_name: str = Field(description="The name of the mask image")
masked_latens_name: Optional[str] = Field(description="The name of the masked image latents") masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
class InpaintMaskOutput(BaseInvocationOutput):
"""Base class for nodes that output a single image"""
type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run")
# endregion # endregion

View File

@ -342,6 +342,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]: ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
if init_timestep.shape[0] == 0: if init_timestep.shape[0] == 0:
@ -375,11 +376,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
) )
if is_inpainting_model(self.unet): if is_inpainting_model(self.unet):
# You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint if masked_latents is None:
# (that's why there's a mask!) but it seems to really want that blanked out. raise Exception("Source image required for inpaint mask when inpaint model used!")
masked_latents = orig_latents * torch.where(mask < 0.5, 1, 0)
# TODO: we should probably pass this in so we don't have to try/finally around setting it.
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents( self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
self._unet_forward, mask, masked_latents self._unet_forward, mask, masked_latents
) )