mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
81 lines
3.3 KiB
Python
81 lines
3.3 KiB
Python
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
import torchvision.transforms as T
|
||
|
from PIL import Image
|
||
|
from torchvision.transforms.functional import resize as tv_resize
|
||
|
|
||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||
|
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField
|
||
|
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||
|
from invokeai.app.invocations.latent import DEFAULT_PRECISION
|
||
|
from invokeai.app.invocations.model import VAEField
|
||
|
from invokeai.app.invocations.primitives import DenoiseMaskOutput
|
||
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
||
|
|
||
|
|
||
|
@invocation(
    "create_denoise_mask",
    title="Create Denoise Mask",
    tags=["mask", "denoise"],
    category="latents",
    version="1.0.2",
)
class CreateDenoiseMaskInvocation(BaseInvocation):
    """Creates mask for denoising model run.

    The mask image is turned into a tensor and stored. When an ``image`` is
    also supplied, the image is masked, encoded with the connected VAE, and the
    resulting latents are stored alongside the mask.
    """

    # Node inputs; `ui_order` fixes the order in which they are presented.
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection, ui_order=0)
    image: Optional[ImageField] = InputField(default=None, description="Image which will be masked", ui_order=1)
    mask: ImageField = InputField(description="The mask to use when pasting", ui_order=2)
    tiled: bool = InputField(default=False, description=FieldDescriptions.tiled, ui_order=3)
    fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32, ui_order=4)

    def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
        """Convert a PIL mask image into a mask tensor.

        Args:
            mask_image: The mask as a PIL image, in any mode.

        Returns:
            The tensor produced by ``image_resized_to_grid_as_tensor`` (called
            with ``normalize=False``) from the single-channel ("L") version of
            the mask. A 3-dimensional result gets an extra leading dimension.
        """
        # Only convert when the image is not already single-channel greyscale.
        greyscale = mask_image if mask_image.mode == "L" else mask_image.convert("L")
        tensor: torch.Tensor = image_resized_to_grid_as_tensor(greyscale, normalize=False)
        return tensor.unsqueeze(0) if tensor.dim() == 3 else tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
        """Build and store the denoise mask and, if an image is given, its masked latents.

        Args:
            context: Invocation context used to read images, load the VAE and
                save the resulting tensors.

        Returns:
            A ``DenoiseMaskOutput`` built from the saved mask name and the saved
            masked-latents name (``None`` when no image was provided), with
            ``gradient=False``.
        """
        # Step 1: load the optional source image as an RGB tensor.
        source_tensor: Optional[torch.Tensor] = None
        if self.image is not None:
            source_pil = context.images.get_pil(self.image.image_name)
            source_tensor = image_resized_to_grid_as_tensor(source_pil.convert("RGB"))
            if source_tensor.dim() == 3:
                source_tensor = source_tensor.unsqueeze(0)

        # Step 2: load the mask and turn it into a tensor.
        mask_pil = context.images.get_pil(self.mask.image_name)
        mask = self.prep_mask_tensor(mask_pil)

        # Step 3: with a source image, blank out the masked area and encode the result.
        masked_latents_name: Optional[str] = None
        if source_tensor is not None:
            vae_info = context.models.load(self.vae.vae)

            # Match the mask to the image tensor's last two (spatial) dimensions.
            resized_mask = tv_resize(mask, source_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
            # Binarise: pixels where the mask is below 0.5 are zeroed, all others are kept unchanged.
            keep_factor = torch.where(resized_mask < 0.5, 0.0, 1.0)
            masked_image = source_tensor * keep_factor

            masked_latents = ImageToLatentsInvocation.vae_encode(
                vae_info,
                self.fp32,
                self.tiled,
                masked_image.clone(),
            )
            masked_latents_name = context.tensors.save(tensor=masked_latents)

        # Step 4: the mask saved here is the one from step 2, not the resized copy.
        mask_name = context.tensors.save(tensor=mask)

        return DenoiseMaskOutput.build(
            mask_name=mask_name,
            masked_latents_name=masked_latents_name,
            gradient=False,
        )
|