InvokeAI/invokeai/app/invocations/infill.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

221 lines
7.2 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
2023-05-05 05:16:26 +00:00
2023-07-03 16:17:45 +00:00
from typing import Literal, Optional, get_args
2023-05-05 05:16:26 +00:00
import numpy as np
import math
from PIL import Image, ImageOps
from invokeai.app.invocations.primitives import ImageField, ImageOutput, ColorField
2023-05-05 05:16:26 +00:00
2023-05-06 09:06:39 +00:00
from invokeai.app.util.misc import SEED_MAX, get_random_seed
2023-05-05 05:16:26 +00:00
from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags
2023-05-05 05:16:26 +00:00
def infill_methods() -> list[str]:
    """Return the names of the infill methods usable on this system.

    "tile" and "solid" are always offered. "patchmatch" is placed first in
    the list, but only when PatchMatch reports itself as available.
    """
    always_available = ["tile", "solid"]
    if not PatchMatch.patchmatch_available():
        return always_available
    return ["patchmatch", *always_available]
# Literal type whose members are the infill methods reported by infill_methods() at import time.
INFILL_METHODS = Literal[tuple(infill_methods())]
# Use PatchMatch when it is one of the available methods; otherwise fall back to tiling.
DEFAULT_INFILL_METHOD = "tile" if "patchmatch" not in get_args(INFILL_METHODS) else "patchmatch"
def infill_patchmatch(im: Image.Image) -> Image.Image:
    """Fill the transparent regions of an RGBA image using PatchMatch.

    The image is returned untouched when it is not in "RGBA" mode or when
    PatchMatch is not available. Otherwise the result is a new "RGB" image.
    """
    # Nothing to do without an alpha channel, and nothing we can do without PatchMatch
    if im.mode != "RGBA" or not PatchMatch.patchmatch_available():
        return im

    # The inverted alpha band is passed as the mask argument
    rgb = im.convert("RGB")
    inverted_alpha = ImageOps.invert(im.split()[-1])

    # Note: patch_size could be exposed, but increasing it significantly impacts performance
    patched_pixels = PatchMatch.inpaint(rgb, inverted_alpha, patch_size=3)
    return Image.fromarray(patched_pixels, mode="RGB")
def get_tile_images(image: np.ndarray, width: int = 8, height: int = 8) -> Optional[np.ndarray]:
    """Split a 3-D image array into a grid of equally sized tiles.

    Args:
        image: Array of shape (rows, cols, depth).
        width: Tile width, measured along axis 1 of ``image``.
        height: Tile height, measured along axis 0 of ``image``.

    Returns:
        A read-only array of shape (nrows, ncols, height, width, depth) where
        ``result[i, j]`` is the tile covering
        ``image[i*height:(i+1)*height, j*width:(j+1)*width]``, or ``None`` when
        the image dimensions are not exact multiples of the tile size.
    """
    # as_strided trusts the strides it is given without any bounds checking, so the
    # strides must describe the exact buffer that is handed to it. Normalising to a
    # C-contiguous array guarantees that ravel() below is a view of that same buffer
    # (for non-contiguous input ravel() would return a copy with a different layout).
    image = np.ascontiguousarray(image)
    total_rows, total_cols, depth = image.shape

    nrows, leftover_rows = divmod(total_rows, height)
    ncols, leftover_cols = divmod(total_cols, width)
    if leftover_rows != 0 or leftover_cols != 0:
        return None

    row_stride, col_stride, depth_stride = image.strides
    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        # Stepping one tile down/right skips a whole tile's worth of rows/columns;
        # within a tile the original per-pixel strides apply.
        strides=(height * row_stride, width * col_stride, row_stride, col_stride, depth_stride),
        writeable=False,
    )
2023-07-03 16:17:45 +00:00
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
    """Replace tiles containing transparent pixels with randomly chosen opaque tiles.

    The image is cut into ``tile_size`` x ``tile_size`` tiles. A tile is "valid"
    when every one of its pixels has alpha > 0. Each tile that is not valid is
    overwritten with a valid tile picked at random (with replacement).

    Args:
        im: The image to fill. Only "RGBA" images are processed.
        tile_size: Edge length of the square tiles, in pixels.
        seed: Seed for the random tile selection; ``None`` gives a random result.

    Returns:
        ``im`` itself when it is not RGBA or when it contains no valid tile;
        otherwise a new "RGBA" image with the invalid tiles replaced.

    Raises:
        ValueError: If the image dimensions are not multiples of ``tile_size``.
    """
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    pixels = np.asarray(im, dtype=np.uint8)

    # Get the image as tiles of the requested size
    tiles = get_tile_images(pixels, tile_size, tile_size)
    if tiles is None:
        # Previously this surfaced as an opaque AttributeError on None.copy()
        raise ValueError(f"Image size {im.size} is not a multiple of tile_size {tile_size}")
    # The tiled view is read-only; take a private copy we can write replacements into
    tiles = tiles.copy()

    # Flatten the 2-D grid of tiles into a single list of tiles: (n, height, width, depth)
    grid_shape = tiles.shape
    tile_count = math.prod(grid_shape[0:2])
    tiles_all = tiles.reshape((tile_count, *grid_shape[2:]))

    # A tile is valid only if every pixel of its alpha channel is non-zero
    alpha = tiles_all[:, :, :, 3]
    tiles_mask = (alpha > 0).reshape((tile_count, math.prod(grid_shape[2:4]))).all(axis=1)

    # Without at least one valid tile there is nothing to fill from
    filtered_tiles = tiles_all[tiles_mask]
    if len(filtered_tiles) == 0:
        return im

    # Find all invalid tiles and replace each with a random valid tile.
    # NOTE: the count must come from an element-wise inversion of the mask; an
    # identity test (`mask is False`) yields a plain bool, not a per-tile array.
    invalid_tiles = np.logical_not(tiles_mask)
    replace_count = int(np.count_nonzero(invalid_tiles))
    rng = np.random.default_rng(seed=seed)
    tiles_all[invalid_tiles] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

    # Convert back to an image: restore the grid, interleave tile rows with pixel
    # rows (and tile columns with pixel columns), then merge them.
    regrouped = tiles_all.reshape(grid_shape).swapaxes(1, 2)
    stitched = regrouped.reshape(
        (
            math.prod(regrouped.shape[0:2]),
            math.prod(regrouped.shape[2:4]),
            regrouped.shape[4],
        )
    )
    return Image.fromarray(stitched, mode="RGBA")
@title("Solid Color Infill")
@tags("image", "inpaint")
class InfillColorInvocation(BaseInvocation):
    """Infills transparent areas of an image with a solid color"""

    # Fixed literal naming this invocation type
    type: Literal["infill_rgba"] = "infill_rgba"

    # Node inputs
    image: ImageField = InputField(description="The image to infill")
    color: ColorField = InputField(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Composite the input image over a solid backdrop and store the result.

        Returns an ImageOutput naming the newly created image along with its
        width and height.
        """
        source = context.services.images.get_pil_image(self.image.image_name)

        # Lay the source over a same-sized canvas filled with the chosen color
        backdrop = Image.new("RGBA", source.size, self.color.tuple())
        result = Image.alpha_composite(backdrop, source.convert("RGBA"))

        # Paste the original back on top, masked by its last band
        last_band = source.split()[-1]
        result.paste(source, (0, 0), last_band)

        saved = context.services.images.create(
            image=result,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        output_field = ImageField(image_name=saved.image_name)
        return ImageOutput(image=output_field, width=saved.width, height=saved.height)
@title("Tile Infill")
@tags("image", "inpaint")
class InfillTileInvocation(BaseInvocation):
    """Infills transparent areas of an image with tiles of the image"""

    # Fixed literal naming this invocation type
    type: Literal["infill_tile"] = "infill_tile"

    # Node inputs
    image: ImageField = InputField(description="The image to infill")
    tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
    seed: int = InputField(
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Tile-fill a copy of the input image and store the result.

        Returns an ImageOutput naming the newly created image along with its
        width and height.
        """
        source = context.services.images.get_pil_image(self.image.image_name)

        # Work on a copy so the fetched image object itself is not handed to the filler
        result = tile_fill_missing(source.copy(), seed=self.seed, tile_size=self.tile_size)

        # Paste the original back on top, masked by its last band
        last_band = source.split()[-1]
        result.paste(source, (0, 0), last_band)

        saved = context.services.images.create(
            image=result,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        output_field = ImageField(image_name=saved.image_name)
        return ImageOutput(image=output_field, width=saved.width, height=saved.height)
@title("PatchMatch Infill")
@tags("image", "inpaint")
class InfillPatchMatchInvocation(BaseInvocation):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    # Fixed literal naming this invocation type
    type: Literal["infill_patchmatch"] = "infill_patchmatch"

    # Node inputs
    image: ImageField = InputField(description="The image to infill")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Run PatchMatch infill on a copy of the input image and store the result.

        Returns an ImageOutput naming the newly created image along with its
        width and height.

        Raises:
            ValueError: If PatchMatch is not available.
        """
        source = context.services.images.get_pil_image(self.image.image_name)

        # Fail loudly rather than handing back an unfilled image
        if not PatchMatch.patchmatch_available():
            raise ValueError("PatchMatch is not available on this system")

        result = infill_patchmatch(source.copy())

        saved = context.services.images.create(
            image=result,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        output_field = ImageField(image_name=saved.image_name)
        return ImageOutput(image=output_field, width=saved.width, height=saved.height)