feat(nodes): add image mul, channel, convert nodes

also make img node names consistent
This commit is contained in:
psychedelicious 2023-05-24 21:35:46 +10:00 committed by Kent Keirsey
parent 66ad04fcfc
commit 460d555a3d

View File

@ -4,7 +4,7 @@ import io
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
import numpy import numpy
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageCategory, ImageField, ImageType from ..models.image import ImageCategory, ImageField, ImageType
@ -112,11 +112,11 @@ class ShowImageInvocation(BaseInvocation):
) )
class CropImageInvocation(BaseInvocation, PILInvocationConfig): class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
"""Crops an image to a specified box. The box can be outside of the image.""" """Crops an image to a specified box. The box can be outside of the image."""
# fmt: off # fmt: off
type: Literal["crop"] = "crop" type: Literal["img_crop"] = "img_crop"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to crop") image: Union[ImageField, None] = Field(default=None, description="The image to crop")
@ -154,11 +154,11 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig):
) )
class PasteImageInvocation(BaseInvocation, PILInvocationConfig): class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
"""Pastes an image into another image.""" """Pastes an image into another image."""
# fmt: off # fmt: off
type: Literal["paste"] = "paste" type: Literal["img_paste"] = "img_paste"
# Inputs # Inputs
base_image: Union[ImageField, None] = Field(default=None, description="The base image") base_image: Union[ImageField, None] = Field(default=None, description="The base image")
@ -238,7 +238,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_mask, image=image_mask,
image_type=ImageType.INTERMEDIATE, image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.MASK,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
) )
@ -252,11 +252,124 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
) )
class BlurInvocation(BaseInvocation, PILInvocationConfig): class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
"""Multiplies two images together using `PIL.ImageChops.multiply()`."""
# fmt: off
type: Literal["img_mul"] = "img_mul"
# Inputs
image1: Union[ImageField, None] = Field(default=None, description="The first image to multiply")
image2: Union[ImageField, None] = Field(default=None, description="The second image to multiply")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image(
self.image1.image_type, self.image1.image_name
)
image2 = context.services.images.get_pil_image(
self.image2.image_type, self.image2.image_name
)
multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create(
image=multiply_image,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
"""Gets a channel from an image."""
# fmt: off
type: Literal["img_chan"] = "img_chan"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to get the channel from")
channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create(
image=channel_image,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
IMAGE_MODES = Literal['L', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F']
class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
"""Converts an image to a different mode."""
# fmt: off
type: Literal["img_conv"] = "img_conv"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to convert")
mode: IMAGE_MODES = Field(default="L", description="The mode to convert to")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
converted_image = image.convert(self.mode)
image_dto = context.services.images.create(
image=converted_image,
image_type=ImageType.INTERMEDIATE,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
)
return ImageOutput(
image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name
),
width=image_dto.width,
height=image_dto.height,
)
class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Blurs an image""" """Blurs an image"""
# fmt: off # fmt: off
type: Literal["blur"] = "blur" type: Literal["img_blur"] = "img_blur"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to blur") image: Union[ImageField, None] = Field(default=None, description="The image to blur")
@ -294,11 +407,11 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig):
) )
class LerpInvocation(BaseInvocation, PILInvocationConfig): class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image""" """Linear interpolation of all pixels of an image"""
# fmt: off # fmt: off
type: Literal["lerp"] = "lerp" type: Literal["img_lerp"] = "img_lerp"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp") image: Union[ImageField, None] = Field(default=None, description="The image to lerp")
@ -334,11 +447,11 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig):
) )
class InverseLerpInvocation(BaseInvocation, PILInvocationConfig): class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Inverse linear interpolation of all pixels of an image""" """Inverse linear interpolation of all pixels of an image"""
# fmt: off # fmt: off
type: Literal["ilerp"] = "ilerp" type: Literal["img_ilerp"] = "img_ilerp"
# Inputs # Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to lerp") image: Union[ImageField, None] = Field(default=None, description="The image to lerp")