InvokeAI/invokeai/app/invocations/infill.py

from abc import abstractmethod
from typing import Literal, get_args

from PIL import Image

from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.infill_methods.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.infill_methods.lama import LaMA
from invokeai.backend.image_util.infill_methods.mosaic import infill_mosaic
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch, infill_patchmatch
from invokeai.backend.image_util.infill_methods.tile import infill_tile
from invokeai.backend.util.logging import InvokeAILogger

from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES

logger = InvokeAILogger.get_logger()


def get_infill_methods():
    methods = Literal["tile", "color", "lama", "cv2"]  # TODO: add mosaic back
    if PatchMatch.patchmatch_available():
        methods = Literal["patchmatch", "tile", "color", "lama", "cv2"]  # TODO: add mosaic back
    return methods


INFILL_METHODS = get_infill_methods()
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
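

# Sketch (illustrative only, not part of the InvokeAI API): the hypothetical
# helpers `_sketch_infill_methods` and `_sketch_default_method` show how the
# Literal/get_args pattern above picks the default infill method. They assume a
# plain `patchmatch_available` flag in place of PatchMatch.patchmatch_available(),
# so they run without the PatchMatch backend installed.
from typing import Literal, get_args


def _sketch_infill_methods(patchmatch_available: bool):
    # Build the Literal type that enumerates the selectable infill methods
    if patchmatch_available:
        return Literal["patchmatch", "tile", "color", "lama", "cv2"]
    return Literal["tile", "color", "lama", "cv2"]


def _sketch_default_method(methods) -> str:
    # get_args() returns the Literal's values as a tuple, so a membership test
    # selects "patchmatch" only when it is actually on offer
    return "patchmatch" if "patchmatch" in get_args(methods) else "tile"


# Usage: without PatchMatch, the default falls back to "tile"
# _sketch_default_method(_sketch_infill_methods(False))  -> "tile"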


class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Base class for invocations that preprocess images for infilling"""

    image: ImageField = InputField(description="The image to process")

    @abstractmethod
    def infill(self, image: Image.Image) -> Image.Image:
        """Infill the image with the specified method"""
        pass

    def load_image(self) -> tuple[Image.Image, bool]:
        """Load the input image and report whether it has an RGBA alpha channel"""
        image = self._context.images.get_pil(self.image.image_name)
        has_alpha = image.mode == "RGBA"
        return image, has_alpha

    def invoke(self, context: InvocationContext) -> ImageOutput:
        self._context = context
        # Retrieve the image to be infilled
        input_image, has_alpha = self.load_image()

        # If the input image has no alpha channel, there is nothing to infill: return it unchanged
        if not has_alpha:
            return ImageOutput.build(context.images.get_dto(self.image.image_name))

        # Perform the infill
        infilled_image = self.infill(input_image)

        # Save the infilled image and get its ImageDTO
        infilled_image_dto = context.images.save(image=infilled_image)

        # Return the infilled image
        return ImageOutput.build(infilled_image_dto)
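

# Sketch (illustrative only): the early-return rule used by invoke() above,
# written as hypothetical pure functions over a PIL mode string. Because only
# mode "RGBA" counts as having alpha, an image in another alpha-carrying mode
# such as "LA" (grayscale + alpha) is returned unchanged rather than infilled.


def _sketch_should_infill(mode: str) -> bool:
    # Mirrors `has_alpha = image.mode == "RGBA"` in load_image()
    return mode == "RGBA"


def _sketch_run(mode: str, infill, image):
    # Hypothetical stand-in for invoke(): skip infilling when there is no RGBA alpha
    if not _sketch_should_infill(mode):
        return image
    return infill(image)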


@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class InfillColorInvocation(InfillImageProcessorInvocation):
    """Infills transparent areas of an image with a solid color"""

    color: ColorField = InputField(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def infill(self, image: Image.Image):
        # Composite the image over a background filled with the chosen color
        solid_bg = Image.new("RGBA", image.size, self.color.tuple())
        infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
        # Paste the original image back, using its alpha channel as the mask
        infilled.paste(image, (0, 0), image.split()[-1])
        return infilled
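

# Sketch (illustrative only): the per-pixel "over" blend that the solid-color
# infill relies on, as a hypothetical pure-Python helper over 8-bit RGBA tuples.
# It assumes an opaque fill color; PIL's alpha_composite may round intermediate
# values slightly differently, so treat this as a model of the math, not a
# bit-exact reimplementation.


def _sketch_composite_over_color(pixel, color):
    r, g, b, a = pixel
    alpha = a / 255
    # out = src * alpha + background * (1 - alpha), per RGB channel
    blended = tuple(round(s * alpha + c * (1 - alpha)) for s, c in zip((r, g, b), color[:3]))
    # Compositing over an opaque background always yields an opaque pixel
    return (*blended, 255)


# Usage: a fully transparent pixel takes the fill color outright
# _sketch_composite_over_color((0, 0, 0, 0), (127, 127, 127, 255))  -> (127, 127, 127, 255)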


@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.3")
class InfillTileInvocation(InfillImageProcessorInvocation):
    """Infills transparent areas of an image with tiles of the image"""

    tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
    seed: int = InputField(
        default=0,
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
    )

    def infill(self, image: Image.Image):
        output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
        return output.infilled
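

# Sketch (illustrative only, NOT InvokeAI's infill_tile implementation): a
# simplified, hypothetical tile infill over a 2D list of RGBA tuples, showing
# how `tile_size` and `seed` can interact. Fully opaque tile_size x tile_size
# blocks are collected as source tiles; each fully transparent pixel is replaced
# by the pixel at the same offset in a source tile chosen with a seeded RNG, so
# the same seed always reproduces the same result.
import random


def _sketch_tile_infill(pixels, tile_size, seed):
    height, width = len(pixels), len(pixels[0])
    rng = random.Random(seed)
    # Top-left corners of tiles that lie fully inside the image and have no transparency
    sources = [
        (ty, tx)
        for ty in range(0, height - tile_size + 1, tile_size)
        for tx in range(0, width - tile_size + 1, tile_size)
        if all(pixels[ty + dy][tx + dx][3] == 255 for dy in range(tile_size) for dx in range(tile_size))
    ]
    out = [row[:] for row in pixels]
    if not sources:
        return out  # nothing opaque to sample from
    # Pick one source tile per destination tile position
    for ty in range(0, height, tile_size):
        for tx in range(0, width, tile_size):
            sy, sx = rng.choice(sources)
            for y in range(ty, min(ty + tile_size, height)):
                for x in range(tx, min(tx + tile_size, width)):
                    if pixels[y][x][3] == 0:
                        out[y][x] = pixels[sy + (y - ty)][sx + (x - tx)]
    return out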


@invocation(
    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2"
)
class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speed up infill")
    resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")

    def infill(self, image: Image.Image):
        resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]

        width = int(image.width / self.downscale)
        height = int(image.height / self.downscale)

        # Downscale first so PatchMatch runs on fewer pixels
        infilled = image.resize(
            (width, height),
            resample=resample_mode,
        )
        # Run PatchMatch on the downscaled image, not the full-size original
        infilled = infill_patchmatch(infilled)
        # Upscale the result back to the original size
        infilled = infilled.resize(
            (image.width, image.height),
            resample=resample_mode,
        )
        # Restore the original pixels at full resolution, using the alpha channel as the mask
        infilled.paste(image, (0, 0), mask=image.split()[-1])

        return infilled
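

# Sketch (illustrative only): the size arithmetic behind the `downscale` option
# above, as a hypothetical helper. It mirrors the int() truncation used in
# infill(); the max(1, ...) clamp is an added assumption so that very large
# downscale factors never yield a zero-sized image. With downscale=2.0,
# PatchMatch works on roughly a quarter of the original pixel count.


def _sketch_patchmatch_sizes(width: int, height: int, downscale: float):
    # Size PatchMatch actually works on
    small = (max(1, int(width / downscale)), max(1, int(height / downscale)))
    # Size the result is resized back to before the original pixels are pasted over it
    full = (width, height)
    return small, full


# Usage:
# _sketch_patchmatch_sizes(1024, 768, 2.0)  -> ((512, 384), (1024, 768))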


@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class LaMaInfillInvocation(InfillImageProcessorInvocation):
    """Infills transparent areas of an image using the LaMa model"""

    def infill(self, image: Image.Image):
        with self._context.models.load_remote_model(
            source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
            loader=LaMA.load_jit_model,
        ) as model:
            lama = LaMA(model)
            return lama(image)


@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class CV2InfillInvocation(InfillImageProcessorInvocation):
    """Infills transparent areas of an image using OpenCV Inpainting"""

    def infill(self, image: Image.Image):
        return cv2_inpaint(image)


# @invocation(
#     "infill_mosaic", title="Mosaic Infill", tags=["image", "inpaint", "outpaint"], category="inpaint", version="1.0.0"
# )
class MosaicInfillInvocation(InfillImageProcessorInvocation):
    """Infills transparent areas of an image with a mosaic pattern drawing colors from the rest of the image"""

    image: ImageField = InputField(description="The image to infill")
    tile_width: int = InputField(default=64, description="Width of the tile")
    tile_height: int = InputField(default=64, description="Height of the tile")
    min_color: ColorField = InputField(
        default=ColorField(r=0, g=0, b=0, a=255),
        description="The min threshold for color",
    )
    max_color: ColorField = InputField(
        default=ColorField(r=255, g=255, b=255, a=255),
        description="The max threshold for color",
    )

    def infill(self, image: Image.Image):
        return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())