mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
a1773197e9
- remove `image_origin` from most places where we interact with images - consolidate image file storage into a single `images/` dir Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads. It was included in e.g. API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images. The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it makes getting image paths on disk painful. So, I have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
231 lines
7.2 KiB
Python
231 lines
7.2 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
|
|
|
from typing import Literal, Union, get_args
|
|
|
|
import numpy as np
|
|
import math
|
|
from PIL import Image, ImageOps
|
|
from pydantic import Field
|
|
|
|
from invokeai.app.invocations.image import ImageOutput
|
|
from invokeai.app.util.misc import SEED_MAX, get_random_seed
|
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
|
|
|
from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
|
|
from .baseinvocation import (
|
|
BaseInvocation,
|
|
InvocationContext,
|
|
)
|
|
|
|
|
|
def infill_methods() -> list[str]:
    """Return the names of the infill methods usable on this system.

    "tile" and "solid" are always included. "patchmatch" is prepended only
    when ``PatchMatch.patchmatch_available()`` reports that it can be used.
    """
    optional = ["patchmatch"] if PatchMatch.patchmatch_available() else []
    return [*optional, "tile", "solid"]
|
|
|
|
|
|
# Literal type whose members are exactly the infill methods that
# infill_methods() reports at import time.
INFILL_METHODS = Literal[tuple(infill_methods())]

# Prefer PatchMatch when it is one of the available methods, else fall back to tiling.
if "patchmatch" in get_args(INFILL_METHODS):
    DEFAULT_INFILL_METHOD = "patchmatch"
else:
    DEFAULT_INFILL_METHOD = "tile"
|
|
|
|
|
|
def infill_patchmatch(im: Image.Image) -> Image.Image:
    """Fill the transparent regions of an RGBA image using PatchMatch.

    The input is returned untouched when it is not in "RGBA" mode or when
    PatchMatch is not available. Otherwise the RGB channels are inpainted
    using the inverted alpha channel as the mask, and an "RGB" image is
    returned.
    """
    # Nothing to do without an alpha channel or without the PatchMatch backend.
    if im.mode != "RGBA" or not PatchMatch.patchmatch_available():
        return im

    rgb = im.convert("RGB")
    alpha = im.split()[-1]
    # Inverting alpha gives fully transparent pixels a mask value of 255;
    # presumably PatchMatch treats non-zero mask pixels as the area to fill.
    # patch_size could be exposed, but raising it significantly hurts performance.
    patched = PatchMatch.inpaint(rgb, ImageOps.invert(alpha), patch_size=3)
    return Image.fromarray(patched, mode="RGB")
|
|
|
|
|
|
def get_tile_images(
    image: np.ndarray, width=8, height=8
) -> Union[np.ndarray, None]:
    """Return a read-only tiled view of an ``(rows, cols, depth)`` array.

    Args:
        image: 3-D array indexed as ``(row, column, channel)``.
        width: Tile width in pixels (columns per tile).
        height: Tile height in pixels (rows per tile).

    Returns:
        A non-writeable array of shape
        ``(rows // height, cols // width, height, width, depth)`` where
        ``result[r, c]`` is the tile covering rows ``r*height:(r+1)*height`` and
        columns ``c*width:(c+1)*width``. Returns ``None`` when ``rows`` is not
        a multiple of ``height`` or ``cols`` is not a multiple of ``width``.
    """
    # as_strided interprets `strides` relative to the buffer it is given, so
    # that buffer must be the very array whose strides we use. For a
    # non-contiguous input (e.g. a slice), the previous np.ravel() call silently
    # produced a differently laid-out copy and the view read the wrong memory.
    # ascontiguousarray is a no-op for already-contiguous input.
    image = np.ascontiguousarray(image)
    n_rows, n_cols, depth = image.shape
    row_stride, col_stride, channel_stride = image.strides

    tile_rows, row_rem = divmod(n_rows, height)
    tile_cols, col_rem = divmod(n_cols, width)
    if row_rem != 0 or col_rem != 0:
        return None

    return np.lib.stride_tricks.as_strided(
        image,
        shape=(tile_rows, tile_cols, height, width, depth),
        # Step a whole tile for the grid axes, one pixel for the in-tile axes.
        strides=(
            height * row_stride,
            width * col_stride,
            row_stride,
            col_stride,
            channel_stride,
        ),
        writeable=False,
    )
|
|
|
|
|
|
def tile_fill_missing(
    im: Image.Image, tile_size: int = 16, seed: Union[int, None] = None
) -> Image.Image:
    """Replace tiles that contain fully transparent pixels with random valid tiles.

    The image is cut into ``tile_size`` x ``tile_size`` tiles. A tile is
    "valid" when every one of its pixels has a non-zero alpha value. Every
    invalid tile is overwritten with a valid tile chosen at random (with
    replacement) by a NumPy generator seeded with ``seed``.

    Images whose width/height are not multiples of ``tile_size`` are padded on
    the bottom/right by repeating their edge pixels, processed, and cropped
    back to the original size.

    Args:
        im: Source image. Only "RGBA" images are processed.
        tile_size: Edge length of the square tiles, in pixels.
        seed: Seed for the tile selection; ``None`` gives a random choice.

    Returns:
        A new "RGBA" image of the same size as ``im``. Returns ``im`` itself
        when it is not "RGBA" or when no valid tile exists to copy from.
    """
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    a = np.asarray(im, dtype=np.uint8)
    height, width = a.shape[:2]

    # get_tile_images() returns None unless both dimensions are exact multiples
    # of the tile size (which used to crash on `.copy()`), so pad first and crop
    # at the end. Repeating edge pixels instead of zero-padding avoids making
    # otherwise valid edge tiles look transparent.
    pad_h = -height % tile_size
    pad_w = -width % tile_size
    if pad_h or pad_w:
        a = np.pad(a, ((0, pad_h), (0, pad_w), (0, 0)), mode="edge")

    # Tiled copy of shape (tile_rows, tile_cols, tile_size, tile_size, 4).
    tiles = get_tile_images(a, tile_size, tile_size).copy()
    tiles_shape = tiles.shape
    n_tiles = tiles_shape[0] * tiles_shape[1]

    # Flatten the tile grid so each tile is addressed by a single index.
    # `tiles` is contiguous, so this is a view and writes go back into `tiles`.
    tiles_all = tiles.reshape((n_tiles, *tiles_shape[2:]))

    # A tile is valid only if none of its pixels is fully transparent.
    # An explicit second dimension keeps this working when n_tiles == 0.
    valid = (
        (tiles_all[..., 3] > 0)
        .reshape((n_tiles, tile_size * tile_size))
        .all(axis=1)
    )
    filtered_tiles = tiles_all[valid]
    if len(filtered_tiles) == 0:
        return im

    # Overwrite every invalid tile with a randomly chosen valid one.
    invalid = np.logical_not(valid)
    rng = np.random.default_rng(seed=seed)
    tiles_all[invalid] = filtered_tiles[
        rng.choice(filtered_tiles.shape[0], invalid.sum())
    ]

    # Reassemble the tile grid into (rows, cols, channels) and drop the padding.
    grid = tiles_all.reshape(tiles_shape).swapaxes(1, 2)
    st = grid.reshape(
        (
            grid.shape[0] * grid.shape[1],
            grid.shape[2] * grid.shape[3],
            grid.shape[4],
        )
    )
    st = np.ascontiguousarray(st[:height, :width])
    return Image.fromarray(st, mode="RGBA")
|
|
|
|
|
|
class InfillColorInvocation(BaseInvocation):
    """Infills transparent areas of an image with a solid color"""

    type: Literal["infill_rgba"] = "infill_rgba"
    # Source image; optional at construction time but required by invoke().
    image: Union[ImageField, None] = Field(
        default=None, description="The image to infill"
    )
    # Background color composited underneath the image.
    color: ColorField = Field(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Composite the input image over a solid ``color`` and store the result.

        The new image is saved through ``context.services.images.create`` as an
        internal, general-category image tied to this node and session.

        Raises:
            ValueError: if no input image was provided.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self.image is None:
            raise ValueError("InfillColorInvocation requires an input image")
        image = context.services.images.get_pil_image(self.image.image_name)

        solid_bg = Image.new("RGBA", image.size, self.color.tuple())
        infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))

        # NOTE(review): alpha_composite above already places the image over the
        # background. This paste blends the image in a second time, masked by
        # its last band (alpha for RGBA input, but e.g. blue for RGB input).
        # Pillow's paste also mixes the alpha channels, so semi-transparent
        # pixels can end up with alpha < 255. Kept to preserve existing
        # output — confirm whether this is intended.
        infilled.paste(image, (0, 0), image.split()[-1])

        image_dto = context.services.images.create(
            image=infilled,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
|
|
|
|
|
|
class InfillTileInvocation(BaseInvocation):
    """Infills transparent areas of an image with tiles of the image"""

    type: Literal["infill_tile"] = "infill_tile"

    # Source image; optional at construction time but required by invoke().
    image: Union[ImageField, None] = Field(
        default=None, description="The image to infill"
    )
    # Edge length of the square tiles used for filling.
    tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
    # Seed for choosing replacement tiles; a random seed is drawn when omitted.
    seed: int = Field(
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
        default_factory=get_random_seed,
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Fill transparent tiles with random tiles of the image and store the result.

        The new image is saved through ``context.services.images.create`` as an
        internal, general-category image tied to this node and session.

        Raises:
            ValueError: if no input image was provided.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self.image is None:
            raise ValueError("InfillTileInvocation requires an input image")
        image = context.services.images.get_pil_image(self.image.image_name)

        infilled = tile_fill_missing(
            image.copy(), seed=self.seed, tile_size=self.tile_size
        )
        # Paste the original back on top, masked by its last band (alpha for
        # RGBA input). Opaque pixels are restored from the original, fully
        # transparent ones keep the tiled fill, and partial alpha blends the two.
        infilled.paste(image, (0, 0), image.split()[-1])

        image_dto = context.services.images.create(
            image=infilled,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
|
|
|
|
|
|
class InfillPatchMatchInvocation(BaseInvocation):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    type: Literal["infill_patchmatch"] = "infill_patchmatch"

    # Source image; optional at construction time but required by invoke().
    image: Union[ImageField, None] = Field(
        default=None, description="The image to infill"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Inpaint the image's transparent areas with PatchMatch and store the result.

        The new image is saved through ``context.services.images.create`` as an
        internal, general-category image tied to this node and session.

        Raises:
            ValueError: if no input image was provided, or if PatchMatch is not
                available on this system.
        """
        # Fail with a clear message instead of an AttributeError on None.
        if self.image is None:
            raise ValueError("InfillPatchMatchInvocation requires an input image")
        image = context.services.images.get_pil_image(self.image.image_name)

        if not PatchMatch.patchmatch_available():
            raise ValueError("PatchMatch is not available on this system")

        infilled = infill_patchmatch(image.copy())

        image_dto = context.services.images.create(
            image=infilled,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
|