mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
a1773197e9
- remove `image_origin` from most places where we interact with images - consolidate image file storage into a single `images/` dir Images have an `image_origin` attribute but it is not actually used when retrieving images, nor will it ever be. It is still used when creating images and helps to differentiate between internally generated images and uploads. It was included in eg API routes and image service methods as a holdover from the previous app implementation where images were not managed in a database. Now that we have images in a db, we can do away with this and simplify basically everything that touches images. The one potentially controversial change is to no longer separate internal and external images on disk. If we retain this separation, we have to keep `image_origin` around in a number of spots and it getting image paths on disk painful. So, I am have gotten rid of this organisation. Images are now all stored in `images`, regardless of their origin. As we improve the image management features, this change will hopefully become transparent.
548 lines
18 KiB
Python
548 lines
18 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
import io
|
|
from typing import Literal, Optional, Union
|
|
|
|
import numpy
|
|
from PIL import Image, ImageFilter, ImageOps, ImageChops
|
|
from pydantic import BaseModel, Field
|
|
|
|
from ..models.image import ImageCategory, ImageField, ResourceOrigin
|
|
from .baseinvocation import (
|
|
BaseInvocation,
|
|
BaseInvocationOutput,
|
|
InvocationContext,
|
|
InvocationConfig,
|
|
)
|
|
|
|
|
|
class PILInvocationConfig(BaseModel):
|
|
"""Helper class to provide all PIL invocations with additional config"""
|
|
|
|
class Config(InvocationConfig):
|
|
schema_extra = {
|
|
"ui": {
|
|
"tags": ["PIL", "image"],
|
|
},
|
|
}
|
|
|
|
|
|
class ImageOutput(BaseInvocationOutput):
|
|
"""Base class for invocations that output an image"""
|
|
|
|
# fmt: off
|
|
type: Literal["image_output"] = "image_output"
|
|
image: ImageField = Field(default=None, description="The output image")
|
|
width: int = Field(description="The width of the image in pixels")
|
|
height: int = Field(description="The height of the image in pixels")
|
|
# fmt: on
|
|
|
|
class Config:
|
|
schema_extra = {"required": ["type", "image", "width", "height"]}
|
|
|
|
|
|
class MaskOutput(BaseInvocationOutput):
|
|
"""Base class for invocations that output a mask"""
|
|
|
|
# fmt: off
|
|
type: Literal["mask"] = "mask"
|
|
mask: ImageField = Field(default=None, description="The output mask")
|
|
width: int = Field(description="The width of the mask in pixels")
|
|
height: int = Field(description="The height of the mask in pixels")
|
|
# fmt: on
|
|
|
|
class Config:
|
|
schema_extra = {
|
|
"required": [
|
|
"type",
|
|
"mask",
|
|
]
|
|
}
|
|
|
|
|
|
class LoadImageInvocation(BaseInvocation):
|
|
"""Load an image and provide it as output."""
|
|
|
|
# fmt: off
|
|
type: Literal["load_image"] = "load_image"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(
|
|
default=None, description="The image to load"
|
|
)
|
|
# fmt: on
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=self.image.image_name),
|
|
width=image.width,
|
|
height=image.height,
|
|
)
|
|
|
|
|
|
class ShowImageInvocation(BaseInvocation):
|
|
"""Displays a provided image, and passes it forward in the pipeline."""
|
|
|
|
type: Literal["show_image"] = "show_image"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(
|
|
default=None, description="The image to show"
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
if image:
|
|
image.show()
|
|
|
|
# TODO: how to handle failure?
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=self.image.image_name),
|
|
width=image.width,
|
|
height=image.height,
|
|
)
|
|
|
|
|
|
class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Crops an image to a specified box. The box can be outside of the image."""
|
|
|
|
# fmt: off
|
|
type: Literal["img_crop"] = "img_crop"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to crop")
|
|
x: int = Field(default=0, description="The left x coordinate of the crop rectangle")
|
|
y: int = Field(default=0, description="The top y coordinate of the crop rectangle")
|
|
width: int = Field(default=512, gt=0, description="The width of the crop rectangle")
|
|
height: int = Field(default=512, gt=0, description="The height of the crop rectangle")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
image_crop = Image.new(
|
|
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
|
|
)
|
|
image_crop.paste(image, (-self.x, -self.y))
|
|
|
|
image_dto = context.services.images.create(
|
|
image=image_crop,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Pastes an image into another image."""
|
|
|
|
# fmt: off
|
|
type: Literal["img_paste"] = "img_paste"
|
|
|
|
# Inputs
|
|
base_image: Union[ImageField, None] = Field(default=None, description="The base image")
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to paste")
|
|
mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting")
|
|
x: int = Field(default=0, description="The left x coordinate at which to paste the image")
|
|
y: int = Field(default=0, description="The top y coordinate at which to paste the image")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
base_image = context.services.images.get_pil_image(self.base_image.image_name)
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
mask = (
|
|
None
|
|
if self.mask is None
|
|
else ImageOps.invert(
|
|
context.services.images.get_pil_image(self.mask.image_name)
|
|
)
|
|
)
|
|
# TODO: probably shouldn't invert mask here... should user be required to do it?
|
|
|
|
min_x = min(0, self.x)
|
|
min_y = min(0, self.y)
|
|
max_x = max(base_image.width, image.width + self.x)
|
|
max_y = max(base_image.height, image.height + self.y)
|
|
|
|
new_image = Image.new(
|
|
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
|
|
)
|
|
new_image.paste(base_image, (abs(min_x), abs(min_y)))
|
|
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=new_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Extracts the alpha channel of an image as a mask."""
|
|
|
|
# fmt: off
|
|
type: Literal["tomask"] = "tomask"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from")
|
|
invert: bool = Field(default=False, description="Whether or not to invert the mask")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
image_mask = image.split()[-1]
|
|
if self.invert:
|
|
image_mask = ImageOps.invert(image_mask)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=image_mask,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.MASK,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return MaskOutput(
|
|
mask=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
|
|
|
|
# fmt: off
|
|
type: Literal["img_mul"] = "img_mul"
|
|
|
|
# Inputs
|
|
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
|
|
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image1 = context.services.images.get_pil_image(self.image1.image_name)
|
|
image2 = context.services.images.get_pil_image(self.image2.image_name)
|
|
|
|
multiply_image = ImageChops.multiply(image1, image2)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=multiply_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
|
|
|
|
|
|
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Gets a channel from an image."""
|
|
|
|
# fmt: off
|
|
type: Literal["img_chan"] = "img_chan"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
|
|
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
channel_image = image.getchannel(self.channel)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=channel_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
|
|
|
|
|
|
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Converts an image to a different mode."""
|
|
|
|
# fmt: off
|
|
type: Literal["img_conv"] = "img_conv"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
|
|
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
converted_image = image.convert(self.mode)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=converted_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Blurs an image"""
|
|
|
|
# fmt: off
|
|
type: Literal["img_blur"] = "img_blur"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to blur")
|
|
radius: float = Field(default=8.0, ge=0, description="The blur radius")
|
|
blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
blur = (
|
|
ImageFilter.GaussianBlur(self.radius)
|
|
if self.blur_type == "gaussian"
|
|
else ImageFilter.BoxBlur(self.radius)
|
|
)
|
|
blur_image = image.filter(blur)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=blur_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
PIL_RESAMPLING_MODES = Literal[
|
|
"nearest",
|
|
"box",
|
|
"bilinear",
|
|
"hamming",
|
|
"bicubic",
|
|
"lanczos",
|
|
]
|
|
|
|
|
|
PIL_RESAMPLING_MAP = {
|
|
"nearest": Image.Resampling.NEAREST,
|
|
"box": Image.Resampling.BOX,
|
|
"bilinear": Image.Resampling.BILINEAR,
|
|
"hamming": Image.Resampling.HAMMING,
|
|
"bicubic": Image.Resampling.BICUBIC,
|
|
"lanczos": Image.Resampling.LANCZOS,
|
|
}
|
|
|
|
|
|
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Resizes an image to specific dimensions"""
|
|
|
|
# fmt: off
|
|
type: Literal["img_resize"] = "img_resize"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
|
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
|
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
|
|
|
resize_image = image.resize(
|
|
(self.width, self.height),
|
|
resample=resample_mode,
|
|
)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=resize_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Scales an image by a factor"""
|
|
|
|
# fmt: off
|
|
type: Literal["img_scale"] = "img_scale"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
|
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
|
width = int(image.width * self.scale_factor)
|
|
height = int(image.height * self.scale_factor)
|
|
|
|
resize_image = image.resize(
|
|
(width, height),
|
|
resample=resample_mode,
|
|
)
|
|
|
|
image_dto = context.services.images.create(
|
|
image=resize_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Linear interpolation of all pixels of an image"""
|
|
|
|
# fmt: off
|
|
type: Literal["img_lerp"] = "img_lerp"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
|
min: int = Field(default=0, ge=0, le=255, description="The minimum output value")
|
|
max: int = Field(default=255, ge=0, le=255, description="The maximum output value")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
|
|
image_arr = image_arr * (self.max - self.min) + self.max
|
|
|
|
lerp_image = Image.fromarray(numpy.uint8(image_arr))
|
|
|
|
image_dto = context.services.images.create(
|
|
image=lerp_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|
|
|
|
|
|
class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|
"""Inverse linear interpolation of all pixels of an image"""
|
|
|
|
# fmt: off
|
|
type: Literal["img_ilerp"] = "img_ilerp"
|
|
|
|
# Inputs
|
|
image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
|
|
min: int = Field(default=0, ge=0, le=255, description="The minimum input value")
|
|
max: int = Field(default=255, ge=0, le=255, description="The maximum input value")
|
|
# fmt: on
|
|
|
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
|
image = context.services.images.get_pil_image(self.image.image_name)
|
|
|
|
image_arr = numpy.asarray(image, dtype=numpy.float32)
|
|
image_arr = (
|
|
numpy.minimum(
|
|
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
|
|
)
|
|
* 255
|
|
)
|
|
|
|
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
|
|
|
|
image_dto = context.services.images.create(
|
|
image=ilerp_image,
|
|
image_origin=ResourceOrigin.INTERNAL,
|
|
image_category=ImageCategory.GENERAL,
|
|
node_id=self.id,
|
|
session_id=context.graph_execution_state_id,
|
|
is_intermediate=self.is_intermediate,
|
|
)
|
|
|
|
return ImageOutput(
|
|
image=ImageField(image_name=image_dto.image_name),
|
|
width=image_dto.width,
|
|
height=image_dto.height,
|
|
)
|