Merge branch 'main' into lstein/bugfix/vram-oom-errors

This commit is contained in:
Lincoln Stein 2024-04-04 22:34:44 -04:00 committed by GitHub
commit 0c9332835d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 796 additions and 250 deletions

View File

@ -12,7 +12,7 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging
from invokeai.version import __version__
@ -100,7 +100,7 @@ async def get_app_deps() -> AppDependencyVersions:
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
async def get_config() -> AppConfig:
infill_methods = ["tile", "lama", "cv2"]
infill_methods = ["tile", "lama", "cv2", "color"] # TODO: add mosaic back
if PatchMatch.patchmatch_available():
infill_methods.append("patchmatch")

View File

@ -1,154 +1,91 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team
from abc import abstractmethod
from typing import Literal, get_args
import math
from typing import Literal, Optional, get_args
import numpy as np
from PIL import Image, ImageOps
from PIL import Image
from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.infill_methods.lama import LaMA
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
from invokeai.backend.image_util.infill_methods.tile import infill_tile
from invokeai.backend.util.logging import InvokeAILogger
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
logger = InvokeAILogger.get_logger()
def infill_methods() -> list[str]:
methods = ["tile", "solid", "lama", "cv2"]
def get_infill_methods():
methods = Literal["tile", "color", "lama", "cv2"] # TODO: add mosaic back
if PatchMatch.patchmatch_available():
methods.insert(0, "patchmatch")
methods = Literal["patchmatch", "tile", "color", "lama", "cv2"] # TODO: add mosaic back
return methods
INFILL_METHODS = Literal[tuple(infill_methods())]
INFILL_METHODS = get_infill_methods()
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
def infill_lama(im: Image.Image) -> Image.Image:
lama = LaMA()
return lama(im)
class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Base class for invocations that preprocess images for Infilling"""
image: ImageField = InputField(description="The image to process")
def infill_patchmatch(im: Image.Image) -> Image.Image:
if im.mode != "RGBA":
return im
@abstractmethod
def infill(self, image: Image.Image) -> Image.Image:
"""Infill the image with the specified method"""
pass
# Skip patchmatch if patchmatch isn't available
if not PatchMatch.patchmatch_available():
return im
def load_image(self, context: InvocationContext) -> tuple[Image.Image, bool]:
"""Process the image to have an alpha channel before being infilled"""
image = context.images.get_pil(self.image.image_name)
has_alpha = True if image.mode == "RGBA" else False
return image, has_alpha
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched
def invoke(self, context: InvocationContext) -> ImageOutput:
# Retrieve and process image to be infilled
input_image, has_alpha = self.load_image(context)
# If the input image has no alpha channel, return it
if has_alpha is False:
return ImageOutput.build(context.images.get_dto(self.image.image_name))
def infill_cv2(im: Image.Image) -> Image.Image:
return cv2_inpaint(im)
# Perform Infill action
infilled_image = self.infill(input_image)
# Create ImageDTO for Infilled Image
infilled_image_dto = context.images.save(image=infilled_image)
def get_tile_images(image: np.ndarray, width=8, height=8):
_nrows, _ncols, depth = image.shape
_strides = image.strides
nrows, _m = divmod(_nrows, height)
ncols, _n = divmod(_ncols, width)
if _m != 0 or _n != 0:
return None
return np.lib.stride_tricks.as_strided(
np.ravel(image),
shape=(nrows, ncols, height, width, depth),
strides=(height * _strides[0], width * _strides[1], *_strides),
writeable=False,
)
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
a = np.asarray(im, dtype=np.uint8)
tile_size_tuple = (tile_size, tile_size)
# Get the image as tiles of a specified size
tiles = get_tile_images(a, *tile_size_tuple).copy()
# Get the mask as tiles
tiles_mask = tiles[:, :, :, :, 3]
# Find any mask tiles with any fully transparent pixels (we will be replacing these later)
tmask_shape = tiles_mask.shape
tiles_mask = tiles_mask.reshape(math.prod(tiles_mask.shape))
n, ny = (math.prod(tmask_shape[0:2])), math.prod(tmask_shape[2:])
tiles_mask = tiles_mask > 0
tiles_mask = tiles_mask.reshape((n, ny)).all(axis=1)
# Get RGB tiles in single array and filter by the mask
tshape = tiles.shape
tiles_all = tiles.reshape((math.prod(tiles.shape[0:2]), *tiles.shape[2:]))
filtered_tiles = tiles_all[tiles_mask]
if len(filtered_tiles) == 0:
return im
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum() # noqa: E712
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
tiles_all = tiles_all.swapaxes(1, 2)
st = tiles_all.reshape(
(
math.prod(tiles_all.shape[0:2]),
math.prod(tiles_all.shape[2:4]),
tiles_all.shape[4],
)
)
si = Image.fromarray(st, mode="RGBA")
return si
# Return Infilled Image
return ImageOutput.build(infilled_image_dto)
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
class InfillColorInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image with a solid color"""
image: ImageField = InputField(description="The image to infill")
color: ColorField = InputField(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
def infill(self, image: Image.Image):
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
return infilled
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
class InfillTileInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image with tiles of the image"""
image: ImageField = InputField(description="The image to infill")
tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
seed: int = InputField(
default=0,
@ -157,92 +94,74 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The seed to use for tile generation (omit for random)",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
def infill(self, image: Image.Image):
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
return output.infilled
@invocation(
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
)
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the PatchMatch algorithm"""
image: ImageField = InputField(description="The image to infill")
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name).convert("RGBA")
def infill(self, image: Image.Image):
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
infill_image = image.copy()
width = int(image.width / self.downscale)
height = int(image.height / self.downscale)
infill_image = infill_image.resize(
infilled = image.resize(
(width, height),
resample=resample_mode,
)
if PatchMatch.patchmatch_available():
infilled = infill_patchmatch(infill_image)
else:
raise ValueError("PatchMatch is not available on this system")
infilled = infill_patchmatch(image)
infilled = infilled.resize(
(image.width, image.height),
resample=resample_mode,
)
infilled.paste(image, (0, 0), mask=image.split()[-1])
# image.paste(infilled, (0, 0), mask=image.split()[-1])
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
return infilled
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
image: ImageField = InputField(description="The image to infill")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
# Downloads the LaMa model if it doesn't already exist
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=context.config.get().models_path / "core/misc/lama/lama.pt",
)
infilled = infill_lama(image.copy())
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
def infill(self, image: Image.Image):
lama = LaMA()
return lama(image)
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
class CV2InfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""
def infill(self, image: Image.Image):
return cv2_inpaint(image)
# @invocation(
# "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
# )
class MosaicInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""
image: ImageField = InputField(description="The image to infill")
tile_width: int = InputField(default=64, description="Width of the tile")
tile_height: int = InputField(default=64, description="Height of the tile")
min_color: ColorField = InputField(
default=ColorField(r=0, g=0, b=0, a=255),
description="The min threshold for color",
)
max_color: ColorField = InputField(
default=ColorField(r=255, g=255, b=255, a=255),
description="The max threshold for color",
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
infilled = infill_cv2(image.copy())
image_dto = context.images.save(image=infilled)
return ImageOutput.build(image_dto)
def infill(self, image: Image.Image):
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())

View File

@ -65,7 +65,7 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1,
ui_type=UIType.IPAdapterModel,
)
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
clip_vision_model: Literal["ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="auto",
ui_order=2,
@ -96,14 +96,9 @@ class IPAdapterInvocation(BaseInvocation):
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
if self.clip_vision_model == "auto":
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
raise RuntimeError(
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
)
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]

View File

@ -1254,7 +1254,7 @@ class IdealSizeInvocation(BaseInvocation):
return tuple((x - x % multiple_of) for x in args)
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
unet_config = context.models.get_config(**self.unet.unet.model_dump())
unet_config = context.models.get_config(self.unet.unet.key)
aspect = self.width / self.height
dimension: float = 512
if unet_config.base == BaseModelType.StableDiffusion2:

View File

@ -2,7 +2,7 @@
Initialization file for invokeai.backend.image_util methods.
"""
from .patchmatch import PatchMatch # noqa: F401
from .infill_methods.patchmatch import PatchMatch # noqa: F401
from .pngwriter import PngWriter, PromptFormatter, retrieve_metadata, write_metadata # noqa: F401
from .seamless import configure_model_padding # noqa: F401
from .util import InitImageResizer, make_grid # noqa: F401

View File

@ -7,6 +7,7 @@ from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import choose_torch_device
@ -30,6 +31,14 @@ class LaMA:
def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
device = choose_torch_device()
model_location = get_config().models_path / "core/misc/lama/lama.pt"
if not model_location.exists():
download_with_progress_bar(
name="LaMa Inpainting Model",
url="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
dest_path=model_location,
)
model = load_jit_model(model_location, device)
image = np.asarray(input_image.convert("RGB"))

View File

@ -0,0 +1,60 @@
from typing import Tuple
import numpy as np
from PIL import Image
def infill_mosaic(
image: Image.Image,
tile_shape: Tuple[int, int] = (64, 64),
min_color: Tuple[int, int, int, int] = (0, 0, 0, 0),
max_color: Tuple[int, int, int, int] = (255, 255, 255, 0),
) -> Image.Image:
"""
image:PIL - A PIL Image
tile_shape: Tuple[int,int] - Tile width & Tile Height
min_color: Tuple[int,int,int] - RGB values for the lowest color to clip to (0-255)
max_color: Tuple[int,int,int] - RGB values for the highest color to clip to (0-255)
"""
np_image = np.array(image) # Convert image to np array
alpha = np_image[:, :, 3] # Get the mask from the alpha channel of the image
non_transparent_pixels = np_image[alpha != 0, :3] # List of non-transparent pixels
# Create color tiles to paste in the empty areas of the image
tile_width, tile_height = tile_shape
# Clip the range of colors in the image to a particular spectrum only
r_min, g_min, b_min, _ = min_color
r_max, g_max, b_max, _ = max_color
non_transparent_pixels[:, 0] = np.clip(non_transparent_pixels[:, 0], r_min, r_max)
non_transparent_pixels[:, 1] = np.clip(non_transparent_pixels[:, 1], g_min, g_max)
non_transparent_pixels[:, 2] = np.clip(non_transparent_pixels[:, 2], b_min, b_max)
tiles = []
for _ in range(256):
color = non_transparent_pixels[np.random.randint(len(non_transparent_pixels))]
tile = np.zeros((tile_height, tile_width, 3), dtype=np.uint8)
tile[:, :] = color
tiles.append(tile)
# Fill the transparent area with tiles
filled_image = np.zeros((image.height, image.width, 3), dtype=np.uint8)
for x in range(image.width):
for y in range(image.height):
tile = tiles[np.random.randint(len(tiles))]
try:
filled_image[
y - (y % tile_height) : y - (y % tile_height) + tile_height,
x - (x % tile_width) : x - (x % tile_width) + tile_width,
] = tile
except ValueError:
# Need to handle edge cases - literally
pass
filled_image = Image.fromarray(filled_image) # Convert the filled tiles image to PIL
image = Image.composite(
image, filled_image, image.split()[-1]
) # Composite the original image on top of the filled tiles
return image

View File

@ -0,0 +1,67 @@
"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""
import numpy as np
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
"""
patch_match = None
tried_load: bool = False
def __init__(self):
super().__init__()
@classmethod
def _load_patch_match(cls):
if cls.tried_load:
return
if get_config().patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
logger.info("Patchmatch initialized")
cls.patch_match = pm
else:
logger.info("Patchmatch not loaded (nonfatal)")
else:
logger.info("Patchmatch loading disabled")
cls.tried_load = True
@classmethod
def patchmatch_available(cls) -> bool:
cls._load_patch_match()
if not cls.patch_match:
return False
return cls.patch_match.patchmatch_available
@classmethod
def inpaint(cls, image: Image.Image) -> Image.Image:
if cls.patch_match is None or not cls.patchmatch_available():
return image
np_image = np.array(image)
mask = 255 - np_image[:, :, 3]
infilled = cls.patch_match.inpaint(np_image[:, :, :3], mask, patch_size=3)
return Image.fromarray(infilled, mode="RGB")
def infill_patchmatch(image: Image.Image) -> Image.Image:
IS_PATCHMATCH_AVAILABLE = PatchMatch.patchmatch_available()
if not IS_PATCHMATCH_AVAILABLE:
logger.warning("PatchMatch is not available on this system")
return image
return PatchMatch.inpaint(image)

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 42 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 49 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 60 KiB

View File

@ -0,0 +1,95 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"Smoke test for the tile infill\"\"\"\n",
"\n",
"from pathlib import Path\n",
"from typing import Optional\n",
"from PIL import Image\n",
"from invokeai.backend.image_util.infill_methods.tile import infill_tile\n",
"\n",
"images: list[tuple[str, Image.Image]] = []\n",
"\n",
"for i in sorted(Path(\"./test_images/\").glob(\"*.webp\")):\n",
" images.append((i.name, Image.open(i)))\n",
" images.append((i.name, Image.open(i).transpose(Image.FLIP_LEFT_RIGHT)))\n",
" images.append((i.name, Image.open(i).transpose(Image.FLIP_TOP_BOTTOM)))\n",
" images.append((i.name, Image.open(i).resize((512, 512))))\n",
" images.append((i.name, Image.open(i).resize((1234, 461))))\n",
"\n",
"outputs: list[tuple[str, Image.Image, Image.Image, Optional[Image.Image]]] = []\n",
"\n",
"for name, image in images:\n",
" try:\n",
" output = infill_tile(image, seed=0, tile_size=32)\n",
" outputs.append((name, image, output.infilled, output.tile_image))\n",
" except ValueError as e:\n",
" print(f\"Skipping image {name}: {e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Display the images in jupyter notebook\n",
"import matplotlib.pyplot as plt\n",
"from PIL import ImageOps\n",
"\n",
"fig, axes = plt.subplots(len(outputs), 3, figsize=(10, 3 * len(outputs)))\n",
"plt.subplots_adjust(hspace=0)\n",
"\n",
"for i, (name, original, infilled, tile_image) in enumerate(outputs):\n",
" # Add a border to each image, helps to see the edges\n",
" size = original.size\n",
" original = ImageOps.expand(original, border=5, fill=\"red\")\n",
" filled = ImageOps.expand(infilled, border=5, fill=\"red\")\n",
" if tile_image:\n",
" tile_image = ImageOps.expand(tile_image, border=5, fill=\"red\")\n",
"\n",
" axes[i, 0].imshow(original)\n",
" axes[i, 0].axis(\"off\")\n",
" axes[i, 0].set_title(f\"Original ({name} - {size})\")\n",
"\n",
" if tile_image:\n",
" axes[i, 1].imshow(tile_image)\n",
" axes[i, 1].axis(\"off\")\n",
" axes[i, 1].set_title(\"Tile Image\")\n",
" else:\n",
" axes[i, 1].axis(\"off\")\n",
" axes[i, 1].set_title(\"NO TILES GENERATED (NO TRANSPARENCY)\")\n",
"\n",
" axes[i, 2].imshow(filled)\n",
" axes[i, 2].axis(\"off\")\n",
" axes[i, 2].set_title(\"Filled\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".invokeai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -0,0 +1,122 @@
from dataclasses import dataclass
from typing import Optional
import numpy as np
from PIL import Image
def create_tile_pool(img_array: np.ndarray, tile_size: tuple[int, int]) -> list[np.ndarray]:
"""
Create a pool of tiles from non-transparent areas of the image by systematically walking through the image.
Args:
img_array: numpy array of the image.
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
Returns:
A list of numpy arrays, each representing a tile.
"""
tiles: list[np.ndarray] = []
rows, cols = img_array.shape[:2]
tile_width, tile_height = tile_size
for y in range(0, rows - tile_height + 1, tile_height):
for x in range(0, cols - tile_width + 1, tile_width):
tile = img_array[y : y + tile_height, x : x + tile_width]
# Check if the image has an alpha channel and the tile is completely opaque
if img_array.shape[2] == 4 and np.all(tile[:, :, 3] == 255):
tiles.append(tile)
elif img_array.shape[2] == 3: # If no alpha channel, append the tile
tiles.append(tile)
if not tiles:
raise ValueError(
"Not enough opaque pixels to generate any tiles. Use a smaller tile size or a different image."
)
return tiles
def create_filled_image(
img_array: np.ndarray, tile_pool: list[np.ndarray], tile_size: tuple[int, int], seed: int
) -> np.ndarray:
"""
Create an image of the same dimensions as the original, filled entirely with tiles from the pool.
Args:
img_array: numpy array of the original image.
tile_pool: A list of numpy arrays, each representing a tile.
tile_size: tuple (tile_width, tile_height) specifying the size of each tile.
Returns:
A numpy array representing the filled image.
"""
rows, cols, _ = img_array.shape
tile_width, tile_height = tile_size
# Prep an empty RGB image
filled_img_array = np.zeros((rows, cols, 3), dtype=img_array.dtype)
# Make the random tile selection reproducible
rng = np.random.default_rng(seed)
for y in range(0, rows, tile_height):
for x in range(0, cols, tile_width):
# Pick a random tile from the pool
tile = tile_pool[rng.integers(len(tile_pool))]
# Calculate the space available (may be less than tile size near the edges)
space_y = min(tile_height, rows - y)
space_x = min(tile_width, cols - x)
# Crop the tile if necessary to fit into the available space
cropped_tile = tile[:space_y, :space_x, :3]
# Fill the available space with the (possibly cropped) tile
filled_img_array[y : y + space_y, x : x + space_x, :3] = cropped_tile
return filled_img_array
@dataclass
class InfillTileOutput:
infilled: Image.Image
tile_image: Optional[Image.Image] = None
def infill_tile(image_to_infill: Image.Image, seed: int, tile_size: int) -> InfillTileOutput:
"""Infills an image with random tiles from the image itself.
If the image is not an RGBA image, it is returned untouched.
Args:
image: The image to infill.
tile_size: The size of the tiles to use for infilling.
Raises:
ValueError: If there are not enough opaque pixels to generate any tiles.
"""
if image_to_infill.mode != "RGBA":
return InfillTileOutput(infilled=image_to_infill)
# Internally, we want a tuple of (tile_width, tile_height). In the future, the tile size can be any rectangle.
_tile_size = (tile_size, tile_size)
np_image = np.array(image_to_infill, dtype=np.uint8)
# Create the pool of tiles that we will use to infill
tile_pool = create_tile_pool(np_image, _tile_size)
# Create an image from the tiles, same size as the original
tile_np_image = create_filled_image(np_image, tile_pool, _tile_size, seed)
# Paste the OG image over the tile image, effectively infilling the area
tile_image = Image.fromarray(tile_np_image, "RGB")
infilled = tile_image.copy()
infilled.paste(image_to_infill, (0, 0), image_to_infill.split()[-1])
# I think we want this to be "RGBA"?
infilled.convert("RGBA")
return InfillTileOutput(infilled=infilled, tile_image=tile_image)

View File

@ -1,49 +0,0 @@
"""
This module defines a singleton object, "patchmatch" that
wraps the actual patchmatch object. It respects the global
"try_patchmatch" attribute, so that patchmatch loading can
be suppressed or deferred
"""
import numpy as np
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
class PatchMatch:
"""
Thin class wrapper around the patchmatch function.
"""
patch_match = None
tried_load: bool = False
def __init__(self):
super().__init__()
@classmethod
def _load_patch_match(self):
if self.tried_load:
return
if get_config().patchmatch:
from patchmatch import patch_match as pm
if pm.patchmatch_available:
logger.info("Patchmatch initialized")
else:
logger.info("Patchmatch not loaded (nonfatal)")
self.patch_match = pm
else:
logger.info("Patchmatch loading disabled")
self.tried_load = True
@classmethod
def patchmatch_available(self) -> bool:
self._load_patch_match()
return self.patch_match and self.patch_match.patchmatch_available
@classmethod
def inpaint(self, *args, **kwargs) -> np.ndarray:
if self.patchmatch_available():
return self.patch_match.inpaint(*args, **kwargs)

View File

@ -684,6 +684,7 @@
"noModelsInstalled": "No Models Installed",
"noModelsInstalledDesc1": "Install models with the",
"noModelSelected": "No Model Selected",
"noMatchingModels": "No matching Models",
"none": "none",
"path": "Path",
"pathToConfig": "Path To Config",
@ -887,6 +888,11 @@
"imageFit": "Fit Initial Image To Output Size",
"images": "Images",
"infillMethod": "Infill Method",
"infillMosaicTileWidth": "Tile Width",
"infillMosaicTileHeight": "Tile Height",
"infillMosaicMinColor": "Min Color",
"infillMosaicMaxColor": "Max Color",
"infillColorValue": "Fill Color",
"info": "Info",
"invoke": {
"addingImagesTo": "Adding images to",

View File

@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import type { ModelType } from 'services/api/types';
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'>;
export type FilterableModelType = Exclude<ModelType, 'onnx' | 'clip_vision'> | 'refiner';
type ModelManagerState = {
_version: 1;

View File

@ -1,6 +1,7 @@
import { Flex } from '@invoke-ai/ui-library';
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import type { FilterableModelType } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import {
@ -9,10 +10,11 @@ import {
useIPAdapterModels,
useLoRAModels,
useMainModels,
useRefinerModels,
useT2IAdapterModels,
useVAEModels,
} from 'services/api/hooks/modelsByType';
import type { AnyModelConfig, ModelType } from 'services/api/types';
import type { AnyModelConfig } from 'services/api/types';
import { FetchingModelsLoader } from './FetchingModelsLoader';
import { ModelListWrapper } from './ModelListWrapper';
@ -27,6 +29,12 @@ const ModelList = () => {
[mainModels, searchTerm, filteredModelType]
);
const [refinerModels, { isLoading: isLoadingRefinerModels }] = useRefinerModels();
const filteredRefinerModels = useMemo(
() => modelsFilter(refinerModels, searchTerm, filteredModelType),
[refinerModels, searchTerm, filteredModelType]
);
const [loraModels, { isLoading: isLoadingLoRAModels }] = useLoRAModels();
const filteredLoRAModels = useMemo(
() => modelsFilter(loraModels, searchTerm, filteredModelType),
@ -63,6 +71,28 @@ const ModelList = () => {
[vaeModels, searchTerm, filteredModelType]
);
const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
filteredRefinerModels.length +
filteredLoRAModels.length +
filteredEmbeddingModels.length +
filteredControlNetModels.length +
filteredT2IAdapterModels.length +
filteredIPAdapterModels.length +
filteredVAEModels.length
);
}, [
filteredControlNetModels.length,
filteredEmbeddingModels.length,
filteredIPAdapterModels.length,
filteredLoRAModels.length,
filteredMainModels.length,
filteredRefinerModels.length,
filteredT2IAdapterModels.length,
filteredVAEModels.length,
]);
return (
<ScrollableContent>
<Flex flexDirection="column" w="full" h="full" gap={4}>
@ -71,6 +101,11 @@ const ModelList = () => {
{!isLoadingMainModels && filteredMainModels.length > 0 && (
<ModelListWrapper title={t('modelManager.main')} modelList={filteredMainModels} key="main" />
)}
{/* Refiner Model List */}
{isLoadingRefinerModels && <FetchingModelsLoader loadingMessage="Loading Refiner Models..." />}
{!isLoadingRefinerModels && filteredRefinerModels.length > 0 && (
<ModelListWrapper title={t('sdxl.refiner')} modelList={filteredRefinerModels} key="refiner" />
)}
{/* LoRAs List */}
{isLoadingLoRAModels && <FetchingModelsLoader loadingMessage="Loading LoRAs..." />}
{!isLoadingLoRAModels && filteredLoRAModels.length > 0 && (
@ -108,6 +143,11 @@ const ModelList = () => {
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)}
{totalFilteredModels === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>{t('modelManager.noMatchingModels')}</Text>
</Flex>
)}
</Flex>
</ScrollableContent>
);
@ -118,12 +158,24 @@ export default memo(ModelList);
const modelsFilter = <T extends AnyModelConfig>(
data: T[],
nameFilter: string,
filteredModelType: ModelType | null
filteredModelType: FilterableModelType | null
): T[] => {
return data.filter((model) => {
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesType = filteredModelType ? model.type === filteredModelType : true;
const matchesType = getMatchesType(model, filteredModelType);
return matchesFilter && matchesType;
});
};
const getMatchesType = (modelConfig: AnyModelConfig, filteredModelType: FilterableModelType | null): boolean => {
if (filteredModelType === 'refiner') {
return modelConfig.base === 'sdxl-refiner';
}
if (filteredModelType === 'main' && modelConfig.base === 'sdxl-refiner') {
return false;
}
return filteredModelType ? modelConfig.type === filteredModelType : true;
};

View File

@ -13,6 +13,7 @@ export const ModelTypeFilter = () => {
const MODEL_TYPE_LABELS: Record<FilterableModelType, string> = useMemo(
() => ({
main: t('modelManager.main'),
refiner: t('sdxl.refiner'),
lora: 'LoRA',
embedding: t('modelManager.textualInversions'),
controlnet: 'ControlNet',

View File

@ -65,6 +65,11 @@ export const buildCanvasOutpaintGraph = async (
infillTileSize,
infillPatchmatchDownscaleSize,
infillMethod,
// infillMosaicTileWidth,
// infillMosaicTileHeight,
// infillMosaicMinColor,
// infillMosaicMaxColor,
infillColorValue,
clipSkip,
seamlessXAxis,
seamlessYAxis,
@ -356,6 +361,28 @@ export const buildCanvasOutpaintGraph = async (
};
}
// TODO: add mosaic back
// if (infillMethod === 'mosaic') {
// graph.nodes[INPAINT_INFILL] = {
// type: 'infill_mosaic',
// id: INPAINT_INFILL,
// is_intermediate,
// tile_width: infillMosaicTileWidth,
// tile_height: infillMosaicTileHeight,
// min_color: infillMosaicMinColor,
// max_color: infillMosaicMaxColor,
// };
// }
if (infillMethod === 'color') {
graph.nodes[INPAINT_INFILL] = {
type: 'infill_rgba',
id: INPAINT_INFILL,
color: infillColorValue,
is_intermediate,
};
}
// Handle Scale Before Processing
if (isUsingScaledDimensions) {
const scaledWidth: number = scaledBoundingBoxDimensions.width;

View File

@ -66,6 +66,11 @@ export const buildCanvasSDXLOutpaintGraph = async (
infillTileSize,
infillPatchmatchDownscaleSize,
infillMethod,
// infillMosaicTileWidth,
// infillMosaicTileHeight,
// infillMosaicMinColor,
// infillMosaicMaxColor,
infillColorValue,
seamlessXAxis,
seamlessYAxis,
canvasCoherenceMode,
@ -365,6 +370,28 @@ export const buildCanvasSDXLOutpaintGraph = async (
};
}
// TODO: add mosaic back
// if (infillMethod === 'mosaic') {
// graph.nodes[INPAINT_INFILL] = {
// type: 'infill_mosaic',
// id: INPAINT_INFILL,
// is_intermediate,
// tile_width: infillMosaicTileWidth,
// tile_height: infillMosaicTileHeight,
// min_color: infillMosaicMinColor,
// max_color: infillMosaicMaxColor,
// };
// }
if (infillMethod === 'color') {
graph.nodes[INPAINT_INFILL] = {
type: 'infill_rgba',
id: INPAINT_INFILL,
is_intermediate,
color: infillColorValue,
};
}
// Handle Scale Before Processing
if (isUsingScaledDimensions) {
const scaledWidth: number = scaledBoundingBoxDimensions.width;

View File

@ -0,0 +1,46 @@
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const ParamInfillColorOptions = () => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillColor: generation.infillColorValue,
})),
[]
);
const { infillColor } = useAppSelector(selector);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
const { t } = useTranslation();
const handleInfillColor = useCallback(
(v: RgbaColor) => {
dispatch(setInfillColorValue(v));
},
[dispatch]
);
return (
<Flex flexDir="column" gap={4}>
<FormControl isDisabled={infillMethod !== 'color'}>
<FormLabel>{t('parameters.infillColorValue')}</FormLabel>
<Box w="full" pt={2} pb={2}>
<IAIColorPicker color={infillColor} onChange={handleInfillColor} />
</Box>
</FormControl>
</Flex>
);
};
export default memo(ParamInfillColorOptions);

View File

@ -0,0 +1,127 @@
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import {
selectGenerationSlice,
setInfillMosaicMaxColor,
setInfillMosaicMinColor,
setInfillMosaicTileHeight,
setInfillMosaicTileWidth,
} from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const ParamInfillMosaicTileSize = () => {
const dispatch = useAppDispatch();
const selector = useMemo(
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillMosaicTileWidth: generation.infillMosaicTileWidth,
infillMosaicTileHeight: generation.infillMosaicTileHeight,
infillMosaicMinColor: generation.infillMosaicMinColor,
infillMosaicMaxColor: generation.infillMosaicMaxColor,
})),
[]
);
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
useAppSelector(selector);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
const { t } = useTranslation();
const handleInfillMosaicTileWidthChange = useCallback(
(v: number) => {
dispatch(setInfillMosaicTileWidth(v));
},
[dispatch]
);
const handleInfillMosaicTileHeightChange = useCallback(
(v: number) => {
dispatch(setInfillMosaicTileHeight(v));
},
[dispatch]
);
const handleInfillMosaicMinColor = useCallback(
(v: RgbaColor) => {
dispatch(setInfillMosaicMinColor(v));
},
[dispatch]
);
const handleInfillMosaicMaxColor = useCallback(
(v: RgbaColor) => {
dispatch(setInfillMosaicMaxColor(v));
},
[dispatch]
);
return (
<Flex flexDir="column" gap={4}>
<FormControl isDisabled={infillMethod !== 'mosaic'}>
<FormLabel>{t('parameters.infillMosaicTileWidth')}</FormLabel>
<CompositeSlider
min={8}
max={256}
value={infillMosaicTileWidth}
defaultValue={64}
onChange={handleInfillMosaicTileWidthChange}
step={8}
fineStep={8}
marks
/>
<CompositeNumberInput
min={8}
max={1024}
value={infillMosaicTileWidth}
defaultValue={64}
onChange={handleInfillMosaicTileWidthChange}
step={8}
fineStep={8}
/>
</FormControl>
<FormControl isDisabled={infillMethod !== 'mosaic'}>
<FormLabel>{t('parameters.infillMosaicTileHeight')}</FormLabel>
<CompositeSlider
min={8}
max={256}
value={infillMosaicTileHeight}
defaultValue={64}
onChange={handleInfillMosaicTileHeightChange}
step={8}
fineStep={8}
marks
/>
<CompositeNumberInput
min={8}
max={1024}
value={infillMosaicTileHeight}
defaultValue={64}
onChange={handleInfillMosaicTileHeightChange}
step={8}
fineStep={8}
/>
</FormControl>
<FormControl isDisabled={infillMethod !== 'mosaic'}>
<FormLabel>{t('parameters.infillMosaicMinColor')}</FormLabel>
<Box w="full" pt={2} pb={2}>
<IAIColorPicker color={infillMosaicMinColor} onChange={handleInfillMosaicMinColor} />
</Box>
</FormControl>
<FormControl isDisabled={infillMethod !== 'mosaic'}>
<FormLabel>{t('parameters.infillMosaicMaxColor')}</FormLabel>
<Box w="full" pt={2} pb={2}>
<IAIColorPicker color={infillMosaicMaxColor} onChange={handleInfillMosaicMaxColor} />
</Box>
</FormControl>
</Flex>
);
};
export default memo(ParamInfillMosaicTileSize);

View File

@ -1,6 +1,8 @@
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import ParamInfillColorOptions from './ParamInfillColorOptions';
import ParamInfillMosaicOptions from './ParamInfillMosaicOptions';
import ParamInfillPatchmatchDownscaleSize from './ParamInfillPatchmatchDownscaleSize';
import ParamInfillTilesize from './ParamInfillTilesize';
@ -14,6 +16,14 @@ const ParamInfillOptions = () => {
return <ParamInfillPatchmatchDownscaleSize />;
}
if (infillMethod === 'mosaic') {
return <ParamInfillMosaicOptions />;
}
if (infillMethod === 'color') {
return <ParamInfillColorOptions />;
}
return null;
};

View File

@ -19,6 +19,7 @@ import type {
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { configChanged } from 'features/system/store/configSlice';
import { clamp } from 'lodash-es';
import type { RgbaColor } from 'react-colorful';
import type { ImageDTO } from 'services/api/types';
import type { GenerationState } from './types';
@ -43,8 +44,6 @@ const initialGenerationState: GenerationState = {
shouldFitToWidthHeight: true,
shouldRandomizeSeed: true,
steps: 50,
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
width: 512,
model: null,
vae: null,
@ -55,6 +54,13 @@ const initialGenerationState: GenerationState = {
shouldUseCpuNoise: true,
shouldShowAdvancedOptions: false,
aspectRatio: { ...initialAspectRatioState },
infillTileSize: 32,
infillPatchmatchDownscaleSize: 1,
infillMosaicTileWidth: 64,
infillMosaicTileHeight: 64,
infillMosaicMinColor: { r: 0, g: 0, b: 0, a: 1 },
infillMosaicMaxColor: { r: 255, g: 255, b: 255, a: 1 },
infillColorValue: { r: 0, g: 0, b: 0, a: 1 },
};
export const generationSlice = createSlice({
@ -116,15 +122,6 @@ export const generationSlice = createSlice({
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
state.canvasCoherenceMinDenoise = action.payload;
},
setInfillMethod: (state, action: PayloadAction<string>) => {
state.infillMethod = action.payload;
},
setInfillTileSize: (state, action: PayloadAction<number>) => {
state.infillTileSize = action.payload;
},
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
state.infillPatchmatchDownscaleSize = action.payload;
},
initialImageChanged: (state, action: PayloadAction<ImageDTO>) => {
const { image_name, width, height } = action.payload;
state.initialImage = { imageName: image_name, width, height };
@ -206,6 +203,30 @@ export const generationSlice = createSlice({
aspectRatioChanged: (state, action: PayloadAction<AspectRatioState>) => {
state.aspectRatio = action.payload;
},
setInfillMethod: (state, action: PayloadAction<string>) => {
state.infillMethod = action.payload;
},
setInfillTileSize: (state, action: PayloadAction<number>) => {
state.infillTileSize = action.payload;
},
setInfillPatchmatchDownscaleSize: (state, action: PayloadAction<number>) => {
state.infillPatchmatchDownscaleSize = action.payload;
},
setInfillMosaicTileWidth: (state, action: PayloadAction<number>) => {
state.infillMosaicTileWidth = action.payload;
},
setInfillMosaicTileHeight: (state, action: PayloadAction<number>) => {
state.infillMosaicTileHeight = action.payload;
},
setInfillMosaicMinColor: (state, action: PayloadAction<RgbaColor>) => {
state.infillMosaicMinColor = action.payload;
},
setInfillMosaicMaxColor: (state, action: PayloadAction<RgbaColor>) => {
state.infillMosaicMaxColor = action.payload;
},
setInfillColorValue: (state, action: PayloadAction<RgbaColor>) => {
state.infillColorValue = action.payload;
},
},
extraReducers: (builder) => {
builder.addCase(configChanged, (state, action) => {
@ -249,8 +270,6 @@ export const {
setShouldFitToWidthHeight,
setShouldRandomizeSeed,
setSteps,
setInfillTileSize,
setInfillPatchmatchDownscaleSize,
initialImageChanged,
modelChanged,
vaeSelected,
@ -264,6 +283,13 @@ export const {
heightChanged,
widthRecalled,
heightRecalled,
setInfillTileSize,
setInfillPatchmatchDownscaleSize,
setInfillMosaicTileWidth,
setInfillMosaicTileHeight,
setInfillMosaicMinColor,
setInfillMosaicMaxColor,
setInfillColorValue,
} = generationSlice.actions;
export const { selectOptimalDimension } = generationSlice.selectors;

View File

@ -17,6 +17,7 @@ import type {
ParameterVAEModel,
ParameterWidth,
} from 'features/parameters/types/parameterSchemas';
import type { RgbaColor } from 'react-colorful';
export interface GenerationState {
_version: 2;
@ -39,8 +40,6 @@ export interface GenerationState {
shouldFitToWidthHeight: boolean;
shouldRandomizeSeed: boolean;
steps: ParameterSteps;
infillTileSize: number;
infillPatchmatchDownscaleSize: number;
width: ParameterWidth;
model: ParameterModel | null;
vae: ParameterVAEModel | null;
@ -51,6 +50,13 @@ export interface GenerationState {
shouldUseCpuNoise: boolean;
shouldShowAdvancedOptions: boolean;
aspectRatio: AspectRatioState;
infillTileSize: number;
infillPatchmatchDownscaleSize: number;
infillMosaicTileWidth: number;
infillMosaicTileHeight: number;
infillMosaicMinColor: RgbaColor;
infillMosaicMaxColor: RgbaColor;
infillColorValue: RgbaColor;
}
export type PayloadActionWithOptimalDimension<T = void> = PayloadAction<T, string, { optimalDimension: number }>;

File diff suppressed because one or more lines are too long