feat(nodes): add resize image and scale image nodes

This commit is contained in:
psychedelicious 2023-05-29 17:51:21 +10:00 committed by Kent Keirsey
parent bce33ea62e
commit bbb4e8f5ef

View File

@ -416,6 +416,115 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
)
PIL_RESAMPLING_MODES = Literal[
"nearest",
"box",
"bilinear",
"hamming",
"bicubic",
"lanczos",
]
PIL_RESAMPLING_MAP = {
"nearest": Image.Resampling.NEAREST,
"box": Image.Resampling.BOX,
"bilinear": Image.Resampling.BILINEAR,
"hamming": Image.Resampling.HAMMING,
"bicubic": Image.Resampling.BICUBIC,
"lanczos": Image.Resampling.LANCZOS,
}
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
"""Resizes an image to specific dimensions"""
# fmt: off
type: Literal["img_resize"] = "img_resize"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
resize_image = image.resize(
(self.width, self.height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
"""Scales an image by a factor"""
# fmt: off
type: Literal["img_scale"] = "img_scale"
# Inputs
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(
self.image.image_origin, self.image.image_name
)
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width * self.scale_factor)
height = int(image.height * self.scale_factor)
resize_image = image.resize(
(width, height),
resample=resample_mode,
)
image_dto = context.services.images.create(
image=resize_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(
image_name=image_dto.image_name,
image_origin=image_dto.image_origin,
),
width=image_dto.width,
height=image_dto.height,
)
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
"""Linear interpolation of all pixels of an image"""