mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip: Initial Infill Methods Refactor
This commit is contained in:
parent
3659219f46
commit
32a6b758cd
@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
from invokeai.backend.util.logging import logging
|
||||
from invokeai.version import __version__
|
||||
@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
|
||||
|
||||
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
|
||||
async def get_config() -> AppConfig:
|
||||
infill_methods = ["tile", "lama", "cv2"]
|
||||
infill_methods = ["tile", "lama", "cv2", "color", "mosaic"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
infill_methods.append("patchmatch")
|
||||
|
||||
|
@ -1,154 +1,91 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
|
||||
from abc import abstractmethod
|
||||
from typing import Literal, get_args
|
||||
|
||||
import math
|
||||
from typing import Literal, Optional, get_args
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.infill_methods.lama import LaMA
|
||||
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
|
||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
|
||||
from invokeai.backend.image_util.infill_methods.tile import infill_tile
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
def infill_methods() -> list[str]:
|
||||
methods = ["tile", "solid", "lama", "cv2"]
|
||||
|
||||
def get_infill_methods():
|
||||
methods = Literal["tile", "color", "lama", "cv2", "mosaic"]
|
||||
if PatchMatch.patchmatch_available():
|
||||
methods.insert(0, "patchmatch")
|
||||
methods = Literal["patchmatch", "tile", "color", "lama", "cv2", "mosaic"]
|
||||
return methods
|
||||
|
||||
|
||||
INFILL_METHODS = Literal[tuple(infill_methods())]
|
||||
INFILL_METHODS = get_infill_methods()
|
||||
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
|
||||
|
||||
|
||||
def infill_lama(im: Image.Image) -> Image.Image:
|
||||
lama = LaMA()
|
||||
return lama(im)
|
||||
class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Base class for invocations that preprocess images for Infilling"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
|
||||
def infill_patchmatch(im: Image.Image) -> Image.Image:
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
@abstractmethod
|
||||
def infill(self, image: Image.Image) -> Image.Image:
|
||||
"""Abstract to perform various infilling techniques"""
|
||||
return image
|
||||
|
||||
# Skip patchmatch if patchmatch isn't available
|
||||
if not PatchMatch.patchmatch_available():
|
||||
return im
|
||||
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
|
||||
"""Process the image to have an alpha channel before being infilled"""
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
has_alpha = True if image.mode == "RGBA" else False
|
||||
return image, has_alpha
|
||||
|
||||
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
|
||||
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
|
||||
im_patched = Image.fromarray(im_patched_np, mode="RGB")
|
||||
return im_patched
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Retrieve and process image to be infilled
|
||||
input_image, has_alpha = self.load_image(context)
|
||||
|
||||
# If the input image has no alpha channel, return it
|
||||
if has_alpha is False:
|
||||
return ImageOutput.build(context.images.get_dto(self.image.image_name))
|
||||
|
||||
def infill_cv2(im: Image.Image) -> Image.Image:
|
||||
return cv2_inpaint(im)
|
||||
# Perform Infill action
|
||||
infilled_image = self.infill(input_image)
|
||||
|
||||
# Create ImageDTO for Infilled Image
|
||||
infilled_image_dto = context.images.save(image=infilled_image)
|
||||
|
||||
def get_tile_images(image: np.ndarray, width=8, height=8):
|
||||
_nrows, _ncols, depth = image.shape
|
||||
_strides = image.strides
|
||||
|
||||
nrows, _m = divmod(_nrows, height)
|
||||
ncols, _n = divmod(_ncols, width)
|
||||
if _m != 0 or _n != 0:
|
||||
return None
|
||||
|
||||
return np.lib.stride_tricks.as_strided(
|
||||
np.ravel(image),
|
||||
shape=(nrows, ncols, height, width, depth),
|
||||
strides=(height * _strides[0], width * _strides[1], *_strides),
|
||||
writeable=False,
|
||||
)
|
||||
|
||||
|
||||
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
|
||||
# Only fill if there's an alpha layer
|
||||
if im.mode != "RGBA":
|
||||
return im
|
||||
|
||||
a = np.asarray(im, dtype=np.uint8)
|
||||
|
||||
tile_size_tuple = (tile_size, tile_size)
|
||||
|
||||
# Get the image as tiles of a specified size
|
||||
tiles = get_tile_images(a, *tile_size_tuple).copy()
|
||||
|
||||
# Get the mask as tiles
|
||||
tiles_mask = tiles[:, :, :, :, 3]
|
||||
|
||||
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
|
||||
tmask_shape = tiles_mask.shape
|
||||
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
|
||||
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
|
||||
tiles_mask = tiles_mask > 0
|
||||
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
|
||||
|
||||
# Get RGB tiles in single array and filter by the mask
|
||||
tshape = tiles.shape
|
||||
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
|
||||
filtered_tiles = tiles_all[tiles_mask]
|
||||
|
||||
if len(filtered_tiles) == 0:
|
||||
return im
|
||||
|
||||
# Find all invalid tiles and replace with a random valid tile
|
||||
replace_count = (tiles_mask == False).sum() # noqa: E712
|
||||
rng = np.random.default_rng(seed=seed)
|
||||
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
|
||||
|
||||
# Convert back to an image
|
||||
tiles_all = tiles_all.reshape(tshape)
|
||||
tiles_all = tiles_all.swapaxes(1, 2)
|
||||
st = tiles_all.reshape(
|
||||
(
|
||||
math.prod(tiles_all.shape[0:2]),
|
||||
math.prod(tiles_all.shape[2:4]),
|
||||
tiles_all.shape[4],
|
||||
)
|
||||
)
|
||||
si = Image.fromarray(st, mode="RGBA")
|
||||
|
||||
return si
|
||||
# Return Infilled Image
|
||||
return ImageOutput.build(infilled_image_dto)
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillColorInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
color: ColorField = InputField(
|
||||
default=ColorField(r=127, g=127, b=127, a=255),
|
||||
description="The color to use to infill",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return infilled
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillTileInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
@ -157,92 +94,76 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
description="The seed to use for tile generation (omit for random)",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
def infill(self, image: Image.Image):
|
||||
infilled = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return infilled
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||
def infill(self, image: Image.Image):
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
infill_image = image.copy()
|
||||
width = int(image.width / self.downscale)
|
||||
height = int(image.height / self.downscale)
|
||||
infill_image = infill_image.resize(
|
||||
|
||||
infilled = image.resize(
|
||||
(width, height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
if PatchMatch.patchmatch_available():
|
||||
infilled = infill_patchmatch(infill_image)
|
||||
else:
|
||||
raise ValueError("PatchMatch is not available on this system")
|
||||
|
||||
infilled = infill_patchmatch(image)
|
||||
infilled = infilled.resize(
|
||||
(image.width, image.height),
|
||||
resample=resample_mode,
|
||||
)
|
||||
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return infilled
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class LaMaInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Downloads the LaMa model if it doesn't already exist
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
|
||||
)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
def infill(self, image: Image.Image):
|
||||
lama = LaMA()
|
||||
return lama(image)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class CV2InfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
def infill(self, image: Image.Image):
|
||||
return cv2_inpaint(image)
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
|
||||
)
|
||||
class MosaicInfillInvocation(InfillImageProcessorInvocation):
|
||||
"""Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
tile_width: int = InputField(default=64, description="Width of the tile")
|
||||
tile_height: int = InputField(default=64, description="Height of the tile")
|
||||
min_color: ColorField = InputField(
|
||||
default=ColorField(r=0, g=0, b=0, a=255),
|
||||
description="The min threshold for color",
|
||||
)
|
||||
max_color: ColorField = InputField(
|
||||
default=ColorField(r=255, g=255, b=255, a=255),
|
||||
description="The max threshold for color",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
def infill(self, image: Image.Image):
|
||||
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())
|
||||
|
@ -2,7 +2,7 @@
|
||||
Initialization file for invokeai.backend.image_util methods.
|
||||
"""
|
||||
|
||||
from .patchmatch import PatchMatch # noqa: F401
|
||||
from .infill_methods.patchmatch import PatchMatch # noqa: F401
|
||||
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
|
||||
from .seamless import configure_model_padding # noqa: F401
|
||||
from .util import InitImageResizer, make_grid # noqa: F401
|
||||
|
@ -7,6 +7,7 @@ from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
|
||||
@ -30,6 +31,14 @@ class LaMA:
|
||||
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
|
||||
device = choose_torch_device()
|
||||
model_location = get_config().models_path / "core/misc/lama/lama.pt"
|
||||
|
||||
if not model_location.exists():
|
||||
download_with_progress_bar(
|
||||
name="LaMa Inpainting Model",
|
||||
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
|
||||
dest_path=model_location,
|
||||
)
|
||||
|
||||
model = load_jit_model(model_location, device)
|
||||
|
||||
image = np.asarray(input_image.convert("RGB"))
|
56
invokeai/backend/image_util/infill_methods/mosaic.py
Normal file
56
invokeai/backend/image_util/infill_methods/mosaic.py
Normal file
@ -0,0 +1,56 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def infill_mosaic(
    image: Image.Image,
    tile_shape: Tuple[int, int] = (64, 16),
    min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
    max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
) -> Image.Image:
    """Fill the transparent areas of an image with a mosaic of solid-color tiles.

    Tile colors are sampled at random (global NumPy RNG) from the image's own
    non-transparent pixels, after clamping each RGB channel to the
    ``[min_color, max_color]`` range. The original image is then composited on
    top of the mosaic using its alpha channel as the mask.

    Args:
        image: The image to infill. Images that are not in "RGBA" mode are
            returned unchanged, since they carry no transparency to fill.
        tile_shape: ``(tile_width, tile_height)`` of each mosaic tile, in pixels.
        min_color: Per-channel lower bound (0-255). Only the first three
            (R, G, B) components are used; a trailing alpha value is ignored.
        max_color: Per-channel upper bound (0-255). Only the first three
            (R, G, B) components are used; a trailing alpha value is ignored.

    Returns:
        The infilled image, or ``image`` itself when it is not RGBA or has no
        non-transparent pixel to sample colors from.
    """
    # Without an alpha channel there is no mask to read and nothing to infill.
    if image.mode != "RGBA":
        return image

    np_image = np.array(image)
    alpha = np_image[:, :, 3]

    # Boolean-mask indexing returns a copy, so the clipping below never alters np_image.
    opaque_pixels = np_image[alpha != 0, :3]

    # A fully transparent image offers no colors to build the mosaic from.
    if len(opaque_pixels) == 0:
        return image

    # Clip the range of colors in the image to a particular spectrum only.
    for channel, (low, high) in enumerate(zip(min_color[:3], max_color[:3])):
        opaque_pixels[:, channel] = np.clip(opaque_pixels[:, channel], low, high)

    tile_width, tile_height = tile_shape

    # Build a palette of 256 colors drawn from the (clipped) visible pixels.
    palette = opaque_pixels[np.random.randint(len(opaque_pixels), size=256)]

    # Pick one palette color per tile (not per pixel), rounding the grid up so
    # partial tiles along the right and bottom edges are covered too.
    rows = -(-image.height // tile_height)
    cols = -(-image.width // tile_width)
    tile_colors = palette[np.random.randint(len(palette), size=(rows, cols))]

    # Expand every tile color to tile_height x tile_width pixels, then crop the
    # overhang of the edge tiles back to the image size.
    mosaic = np.repeat(np.repeat(tile_colors, tile_height, axis=0), tile_width, axis=1)
    mosaic = np.ascontiguousarray(mosaic[: image.height, : image.width])

    filled_image = Image.fromarray(mosaic)  # Convert the filled tiles image to PIL

    # Composite the original image on top of the filled tiles.
    return Image.composite(image, filled_image, image.split()[-1])
|
67
invokeai/backend/image_util/infill_methods/patchmatch.py
Normal file
67
invokeai/backend/image_util/infill_methods/patchmatch.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
This module defines a singleton object, "patchmatch" that
|
||||
wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
|
||||
class PatchMatch:
    """
    Thin class wrapper around the patchmatch function.

    All state lives on the class: the optional ``patchmatch`` package is
    imported lazily, at most once, and only when the application config has
    ``patchmatch`` enabled.
    """

    # The imported `patchmatch.patch_match` module when it reported itself as
    # available; stays None when loading is disabled or the module is unusable.
    patch_match = None
    # Set after the first load attempt so the import and its log line happen only once.
    tried_load: bool = False

    def __init__(self):
        super().__init__()

    @classmethod
    def _load_patch_match(cls):
        """Attempt to import patchmatch once and record the outcome on the class."""
        if cls.tried_load:
            return
        if get_config().patchmatch:
            from patchmatch import patch_match as pm

            if pm.patchmatch_available:
                logger.info("Patchmatch initialized")
                cls.patch_match = pm
            else:
                logger.info("Patchmatch not loaded (nonfatal)")
        else:
            logger.info("Patchmatch loading disabled")
        cls.tried_load = True

    @classmethod
    def patchmatch_available(cls) -> bool:
        """Return True when patchmatch has been loaded and reports itself usable.

        Triggers the lazy load on first use.
        """
        cls._load_patch_match()
        if not cls.patch_match:
            return False
        return bool(cls.patch_match.patchmatch_available)

    @classmethod
    def inpaint(cls, image: Image.Image) -> Image.Image:
        """Infill the transparent areas of ``image`` with PatchMatch.

        Args:
            image: The image to infill; its alpha channel marks the area to keep.

        Returns:
            An RGB image with the transparent areas filled in, or ``image``
            unchanged when patchmatch is unavailable or the image is not RGBA.
        """
        # patchmatch_available() performs the lazy load, so it has to run before
        # cls.patch_match is inspected; otherwise a first direct call would
        # always see None and silently skip the infill.
        if not cls.patchmatch_available() or cls.patch_match is None:
            return image

        # The mask is read from channel 3; without an alpha channel there is nothing to infill.
        if image.mode != "RGBA":
            return image

        np_image = np.array(image)
        # Inverted alpha: transparent pixels become 255 in the mask handed to patchmatch.
        mask = 255 - np_image[:, :, 3]
        infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
        return Image.fromarray(infilled, mode="RGB")
|
||||
|
||||
|
||||
def infill_patchmatch(image: Image.Image) -> Image.Image:
    """Infill ``image`` with PatchMatch when it is available on this system.

    When PatchMatch is not available a warning is logged and ``image`` is
    returned unchanged.
    """
    if PatchMatch.patchmatch_available():
        return PatchMatch.inpaint(image)

    logger.warning("PatchMatch is not available on this system")
    return image
|
72
invokeai/backend/image_util/infill_methods/tile.py
Normal file
72
invokeai/backend/image_util/infill_methods/tile.py
Normal file
@ -0,0 +1,72 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_tile_images(image: np.ndarray, width: int = 8, height: int = 8):
    """Split a 3-D array into a grid of non-overlapping tiles.

    Args:
        image: Array of shape ``(rows, cols, depth)``.
        width: Tile width, in array columns.
        height: Tile height, in array rows.

    Returns:
        A read-only array of shape
        ``(rows // height, cols // width, height, width, depth)`` where
        ``result[i, j]`` is the tile at grid row ``i`` and grid column ``j``,
        or ``None`` when ``rows``/``cols`` are not exact multiples of
        ``height``/``width``.
    """
    # The strides computed below must describe the very buffer handed to
    # as_strided. np.ravel copies a non-contiguous input into a fresh
    # contiguous buffer, so normalise the layout first to keep both in agreement.
    image = np.ascontiguousarray(image)

    _nrows, _ncols, depth = image.shape
    _strides = image.strides

    nrows, _m = divmod(_nrows, height)
    ncols, _n = divmod(_ncols, width)
    if _m != 0 or _n != 0:
        return None

    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        # Step a whole tile at a time along the two grid axes, then reuse the
        # image's own strides inside each tile.
        strides=(height * _strides[0], width * _strides[1], *_strides),
        writeable=False,
    )
|
||||
|
||||
|
||||
def infill_tile(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
    """Fill the transparent areas of an image with tiles copied from the image itself.

    The image is cut into ``tile_size`` x ``tile_size`` tiles. Every tile that
    contains at least one fully transparent pixel is replaced by a randomly
    chosen tile that has none.

    Args:
        im: The image to infill. Images that are not in "RGBA" mode are
            returned unchanged.
        tile_size: Edge length of the square tiles, in pixels.
        seed: Seed for the random tile selection; ``None`` gives a different
            result on every call.

    Returns:
        A new RGBA image of the same size with the affected tiles replaced, or
        ``im`` itself when it is not RGBA or contains no fully usable tile.
    """
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    a = np.asarray(im, dtype=np.uint8)
    height, width = a.shape[:2]

    # get_tile_images only accepts sizes that are exact multiples of the tile
    # size. Pad the bottom/right edges with zeros (alpha 0, i.e. "missing") so
    # any image size works; the padding is cropped away again at the end.
    pad_bottom = -height % tile_size
    pad_right = -width % tile_size
    if pad_bottom or pad_right:
        a = np.pad(a, ((0, pad_bottom), (0, pad_right), (0, 0)))

    # Get the image as tiles of a specified size (copied, because the view is read-only)
    tiles = get_tile_images(a, tile_size, tile_size).copy()
    grid_shape = tiles.shape

    # Flatten the tile grid to (tile_count, tile_size, tile_size, channels); this
    # is a view, so writes below land in `tiles`.
    tiles_all = tiles.reshape((-1, *grid_shape[2:]))

    # A tile is usable only when every one of its pixels has a non-zero alpha
    tiles_mask = (tiles_all[..., 3] > 0).reshape((len(tiles_all), -1)).all(axis=1)
    filtered_tiles = tiles_all[tiles_mask]

    if len(filtered_tiles) == 0:
        return im

    # Find all invalid tiles and replace with a random valid tile
    invalid = np.logical_not(tiles_mask)
    replace_count = invalid.sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[invalid] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

    # Convert back to an image: restore the grid, interleave tile rows with
    # grid rows, and collapse to (rows, cols, channels).
    st = tiles_all.reshape(grid_shape).swapaxes(1, 2).reshape(a.shape)

    # Drop the padding added above so the result matches the input size.
    st = np.ascontiguousarray(st[:height, :width])

    return Image.fromarray(st, mode="RGBA")
|
@ -1,49 +0,0 @@
|
||||
"""
|
||||
This module defines a singleton object, "patchmatch" that
|
||||
wraps the actual patchmatch object. It respects the global
|
||||
"try_patchmatch" attribute, so that patchmatch loading can
|
||||
be suppressed or deferred
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
"""
|
||||
Thin class wrapper around the patchmatch function.
|
||||
"""
|
||||
|
||||
patch_match = None
|
||||
tried_load: bool = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _load_patch_match(self):
|
||||
if self.tried_load:
|
||||
return
|
||||
if get_config().patchmatch:
|
||||
from patchmatch import patch_match as pm
|
||||
|
||||
if pm.patchmatch_available:
|
||||
logger.info("Patchmatch initialized")
|
||||
else:
|
||||
logger.info("Patchmatch not loaded (nonfatal)")
|
||||
self.patch_match = pm
|
||||
else:
|
||||
logger.info("Patchmatch loading disabled")
|
||||
self.tried_load = True
|
||||
|
||||
@classmethod
|
||||
def patchmatch_available(self) -> bool:
|
||||
self._load_patch_match()
|
||||
return self.patch_match and self.patch_match.patchmatch_available
|
||||
|
||||
@classmethod
|
||||
def inpaint(self, *args, **kwargs) -> np.ndarray:
|
||||
if self.patchmatch_available():
|
||||
return self.patch_match.inpaint(*args, **kwargs)
|
@ -356,6 +356,22 @@ export const buildCanvasOutpaintGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'mosaic') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_mosaic',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate,
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_rgba',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
|
@ -365,6 +365,22 @@ export const buildCanvasSDXLOutpaintGraph = async (
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'mosaic') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_mosaic',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate,
|
||||
};
|
||||
}
|
||||
|
||||
if (infillMethod === 'color') {
|
||||
graph.nodes[INPAINT_INFILL] = {
|
||||
type: 'infill_rgba',
|
||||
id: INPAINT_INFILL,
|
||||
is_intermediate,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Scale Before Processing
|
||||
if (isUsingScaledDimensions) {
|
||||
const scaledWidth: number = scaledBoundingBoxDimensions.width;
|
||||
|
Loading…
Reference in New Issue
Block a user