Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit cfd827cfad (parent e9a294f733)
Added node for creating mask inpaint
@@ -25,6 +25,7 @@ from invokeai.app.invocations.primitives import (
     LatentsField,
     LatentsOutput,
     InpaintMaskField,
+    InpaintMaskOutput,
     build_latents_output,
 )
 from invokeai.app.util.controlnet_utils import prepare_control_image
@@ -66,6 +67,76 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
 
 
+@title("Create inpaint mask")
+@tags("mask", "inpaint")
+class CreateInpaintMaskInvocation(BaseInvocation):
+    """Creates mask for inpaint model run."""
+
+    # Metadata
+    type: Literal["create_inpaint_mask"] = "create_inpaint_mask"
+
+    # Inputs
+    image: Optional[ImageField] = InputField(default=None, description="Image which will be inpainted")
+    mask: ImageField = InputField(description="The mask to use when pasting")
+    vae: VaeField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )
+    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
+    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
+
+    def prep_mask_tensor(self, mask_image):
+        if mask_image.mode != "L":
+            # FIXME: why do we get passed an RGB image here? We can only use single-channel.
+            mask_image = mask_image.convert("L")
+        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
+        if mask_tensor.dim() == 3:
+            mask_tensor = mask_tensor.unsqueeze(0)
+        #if shape is not None:
+        #    mask_tensor = tv_resize(mask_tensor, shape, T.InterpolationMode.BILINEAR)
+        return mask_tensor
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> InpaintMaskOutput:
+        if self.image is not None:
+            image = context.services.images.get_pil_image(self.image.image_name)
+            image = image_resized_to_grid_as_tensor(image.convert("RGB"))
+            if image.dim() == 3:
+                image = image.unsqueeze(0)
+        else:
+            image = None
+
+        mask = self.prep_mask_tensor(
+            context.services.images.get_pil_image(self.mask.image_name),
+        )
+
+        if image is not None:
+            vae_info = context.services.model_manager.get_model(
+                **self.vae.vae.dict(),
+                context=context,
+            )
+
+            img_mask = tv_resize(mask, image.shape[-2:], T.InterpolationMode.BILINEAR)
+            masked_image = image * torch.where(img_mask < 0.5, 0.0, 1.0)
+            # TODO:
+            masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
+
+            masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
+            context.services.latents.save(masked_latents_name, masked_latents)
+        else:
+            masked_latents_name = None
+
+        mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
+        context.services.latents.save(mask_name, mask)
+
+        return InpaintMaskOutput(
+            inpaint_mask=InpaintMaskField(
+                mask_name=mask_name,
+                masked_latents_name=masked_latents_name,
+            ),
+        )
+
+
 def get_scheduler(
     context: InvocationContext,
     scheduler_info: ModelInfo,
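Note on the new node: the key step is the masking arithmetic in invoke(), where the mask is thresholded at 0.5 and multiplied into the image so one region is zeroed before VAE encoding. A minimal standalone sketch of that behavior in plain PyTorch (the toy tensors are hypothetical, not from the repo):

import torch

# Toy stand-ins: a 1x3x8x8 image batch and a 1x1x8x8 soft mask in [0, 1].
image = torch.rand(1, 3, 8, 8)
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0

# Same arithmetic as the node: threshold at 0.5, broadcast over channels,
# and zero out the pixels that fall below the threshold.
binary = torch.where(mask < 0.5, 0.0, 1.0)
masked_image = image * binary

assert torch.all(masked_image[..., 0, 0] == 0.0)               # below 0.5: blanked
assert torch.equal(masked_image[..., 3, 3], image[..., 3, 3])  # at/above 0.5: kept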
@@ -340,19 +411,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return num_inference_steps, timesteps, init_timestep
 
-    def prep_mask_tensor(self, mask, context, lantents):
-        if mask is None:
-            return None
+    def prep_inpaint_mask(self, context, latents):
+        if self.mask is None:
+            return None, None
 
-        mask_image = context.services.images.get_pil_image(mask.image_name)
-        if mask_image.mode != "L":
-            # FIXME: why do we get passed an RGB image here? We can only use single-channel.
-            mask_image = mask_image.convert("L")
-        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
-        if mask_tensor.dim() == 3:
-            mask_tensor = mask_tensor.unsqueeze(0)
-        mask_tensor = tv_resize(mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR)
-        return 1 - mask_tensor
+        mask = context.services.latents.get(self.mask.mask_name)
+        mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
+        if self.mask.masked_latents_name is not None:
+            masked_latents = context.services.latents.get(self.mask.masked_latents_name)
+        else:
+            masked_latents = None
+
+        return 1 - mask, masked_latents
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
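Note: prep_inpaint_mask now pulls the precomputed mask and masked latents from the latents store and only rescales the mask to the latent grid, instead of re-deriving it from a PIL image; the `1 - mask` return keeps the previous sign convention for the denoising loop. A hedged sketch of the resize-then-invert step, assuming tv_resize is the usual alias for torchvision.transforms.functional.resize:

import torch
from torchvision.transforms.functional import resize, InterpolationMode

# A 1x1x64x64 pixel-space mask; Stable Diffusion latents are 8x smaller,
# so the stand-in latent grid is 8x8.
mask = torch.zeros(1, 1, 64, 64)
mask[..., 16:48, 16:48] = 1.0
latent_hw = [8, 8]  # stand-in for latents.shape[-2:]

latent_mask = resize(mask, latent_hw, InterpolationMode.BILINEAR)
inverted = 1 - latent_mask  # what the denoise step actually consumes

print(latent_mask.shape)  # torch.Size([1, 1, 8, 8])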
@@ -373,7 +443,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if seed is None:
             seed = 0
 
-        mask = self.prep_mask_tensor(self.mask, context, latents)
+        mask, masked_latents = self.prep_inpaint_mask(context, latents)
 
         # Get the source node id (we are invoking the prepared node)
         graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
@@ -404,6 +474,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         noise = noise.to(device=unet.device, dtype=unet.dtype)
         if mask is not None:
             mask = mask.to(device=unet.device, dtype=unet.dtype)
+        if masked_latents is not None:
+            masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
 
         scheduler = get_scheduler(
             context=context,
@@ -440,6 +512,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             noise=noise,
             seed=seed,
             mask=mask,
+            masked_latents=masked_latents,
             num_inference_steps=num_inference_steps,
             conditioning_data=conditioning_data,
             control_data=control_data,  # list[ControlNetData]
@@ -661,26 +734,11 @@ class ImageToLatentsInvocation(BaseInvocation):
     tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
     fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32)
 
-    @torch.no_grad()
-    def invoke(self, context: InvocationContext) -> LatentsOutput:
-        # image = context.services.images.get(
-        #     self.image.image_type, self.image.image_name
-        # )
-        image = context.services.images.get_pil_image(self.image.image_name)
-
-        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
-        vae_info = context.services.model_manager.get_model(
-            **self.vae.vae.dict(),
-            context=context,
-        )
-
-        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
-        if image_tensor.dim() == 3:
-            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
-
+    @staticmethod
+    def vae_encode(vae_info, upcast, tiled, image_tensor):
         with vae_info as vae:
             orig_dtype = vae.dtype
-            if self.fp32:
+            if upcast:
                 vae.to(dtype=torch.float32)
 
                 use_torch_2_0_or_xformers = isinstance(
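Note: the encode body is hoisted out of invoke() into a @staticmethod so that other invocations (here CreateInpaintMaskInvocation) can reuse it without constructing an ImageToLatentsInvocation. A minimal sketch of the pattern with hypothetical stand-in classes:

class Encoder:
    """Stand-in for ImageToLatentsInvocation."""

    def __init__(self, tiled: bool):
        self.tiled = tiled

    @staticmethod
    def encode(tiled: bool, data: list) -> list:
        # The core takes its former self.* inputs as parameters,
        # so no instance is needed.
        return data[::2] if tiled else list(data)

    def invoke(self, data: list) -> list:
        # The original entry point just forwards its own fields.
        return self.encode(self.tiled, data)


class MaskNode:
    """Stand-in for CreateInpaintMaskInvocation."""

    def run(self, data: list) -> list:
        # Cross-class reuse, mirroring ImageToLatentsInvocation.vae_encode(...).
        return Encoder.encode(True, data)


assert Encoder(tiled=False).invoke([1, 2, 3]) == [1, 2, 3]
assert MaskNode().run([1, 2, 3, 4]) == [1, 3]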
@@ -705,7 +763,7 @@ class ImageToLatentsInvocation(BaseInvocation):
                 vae.to(dtype=torch.float16)
                 # latents = latents.half()
 
-            if self.tiled:
+            if tiled:
                 vae.enable_tiling()
             else:
                 vae.disable_tiling()
@@ -719,6 +777,27 @@ class ImageToLatentsInvocation(BaseInvocation):
             latents = vae.config.scaling_factor * latents
             latents = latents.to(dtype=orig_dtype)
 
+        return latents
+
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        # image = context.services.images.get(
+        #     self.image.image_type, self.image.image_name
+        # )
+        image = context.services.images.get_pil_image(self.image.image_name)
+
+        # vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
+        vae_info = context.services.model_manager.get_model(
+            **self.vae.vae.dict(),
+            context=context,
+        )
+
+        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
+        if image_tensor.dim() == 3:
+            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
+
+        latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
+
 name = f"{context.graph_execution_state_id}__{self.id}"
         latents = latents.to("cpu")
         context.services.latents.save(name, latents)
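Note: the reassembled invoke() delegates to vae_encode, which follows the standard diffusers AutoencoderKL recipe of encode, sample the posterior, then scale by the config's scaling factor. A hedged standalone sketch; the checkpoint name is only an example:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # example weights

image_tensor = torch.rand(1, 3, 512, 512) * 2 - 1  # the VAE expects [-1, 1]

with torch.no_grad():
    latents = vae.encode(image_tensor).latent_dist.sample()
    latents = vae.config.scaling_factor * latents  # same scaling as in the diff

print(latents.shape)  # torch.Size([1, 4, 64, 64]): 8x spatial downsample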
@@ -314,7 +314,14 @@ class InpaintMaskField(BaseModel):
     """An inpaint mask field"""
 
     mask_name: str = Field(description="The name of the mask image")
-    masked_latens_name: Optional[str] = Field(description="The name of the masked image latents")
+    masked_latents_name: Optional[str] = Field(description="The name of the masked image latents")
+
+
+class InpaintMaskOutput(BaseInvocationOutput):
+    """Base class for nodes that output a single image"""
+
+    type: Literal["inpaint_mask_output"] = "inpaint_mask_output"
+    inpaint_mask: InpaintMaskField = OutputField(description="Mask for inpaint model run")
 
 
 # endregion
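Note: besides fixing the masked_latens_name typo, the Optional annotation means the latents reference may be absent (the node above saves it only when a source image is wired in). A self-contained sketch of the field behavior; default=None is added here so the snippet also runs under pydantic v2, where Optional alone no longer implies a default:

from typing import Optional
from pydantic import BaseModel, Field

class InpaintMaskField(BaseModel):
    """An inpaint mask field"""

    mask_name: str = Field(description="The name of the mask image")
    masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")

field = InpaintMaskField(mask_name="abc123__1_mask")
assert field.masked_latents_name is None  # absent until a source image is encoded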
@@ -342,6 +342,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
         mask: Optional[torch.Tensor] = None,
+        masked_latents: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
     ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
         if init_timestep.shape[0] == 0:
@@ -375,11 +376,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             )
 
             if is_inpainting_model(self.unet):
-                # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
-                # (that's why there's a mask!) but it seems to really want that blanked out.
-                masked_latents = orig_latents * torch.where(mask < 0.5, 1, 0)
+                if masked_latents is None:
+                    raise Exception("Source image required for inpaint mask when inpaint model used!")
 
-                # TODO: we should probably pass this in so we don't have to try/finally around setting it.
                 self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
                     self._unet_forward, mask, masked_latents
                 )
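Note: this hunk carries the behavioral change of the commit. The pipeline previously blanked the latents itself for inpainting models; now the caller must supply VAE-encoded masked latents, and the inpainting-model path fails fast without them. A toy sketch contrasting the two (tensors are hypothetical):

import torch

orig_latents = torch.rand(1, 4, 8, 8)
mask = torch.rand(1, 1, 8, 8)

# Pre-commit fallback: zero the latents wherever the mask is at or above 0.5.
legacy_masked = orig_latents * torch.where(mask < 0.5, 1, 0)

# Post-commit contract: precomputed masked latents are required.
def check(masked_latents):
    if masked_latents is None:
        raise Exception("Source image required for inpaint mask when inpaint model used!")
    return masked_latents

try:
    check(None)
except Exception as err:
    print(err)  # Source image required for inpaint mask when inpaint model used!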