# diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py
# index 7633bfbc16..d048410468 100644
# --- a/invokeai/app/invocations/image.py
# +++ b/invokeai/app/invocations/image.py
# @@ -416,6 +416,115 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
#          )


# Names accepted by the `resample_mode` input of the resize/scale nodes.
# Every name listed here must also be a key of PIL_RESAMPLING_MAP, because the
# nodes look the chosen name up in that mapping without a fallback.
PIL_RESAMPLING_MODES = Literal[
    "nearest",
    "box",
    "bilinear",
    "hamming",
    "bicubic",
    "lanczos",
]


# Maps each `resample_mode` name to the PIL filter handed to `Image.resize`.
PIL_RESAMPLING_MAP = {
    "nearest": Image.Resampling.NEAREST,
    "box": Image.Resampling.BOX,
    "bilinear": Image.Resampling.BILINEAR,
    "hamming": Image.Resampling.HAMMING,
    "bicubic": Image.Resampling.BICUBIC,
    "lanczos": Image.Resampling.LANCZOS,
}


class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
    """Resizes an image to specific dimensions"""

    # fmt: off
    type: Literal["img_resize"] = "img_resize"

    # Inputs
    image: Union[ImageField, None] = Field(default=None, description="The image to resize")
    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
    resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Resize the input image to exactly ``width`` x ``height`` pixels.

        The source image is fetched through ``context.services.images``,
        resized with the filter selected by ``resample_mode`` and stored as a
        new image; the aspect ratio is not preserved.

        Returns:
            An ``ImageOutput`` referencing the newly created image, with the
            width and height reported for the stored image.
        """
        # NOTE(review): `image` defaults to None, so this attribute access
        # fails if the node runs without an image connected - confirm that
        # the graph guarantees the input is set.
        image = context.services.images.get_pil_image(
            self.image.image_origin, self.image.image_name
        )

        resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]

        resize_image = image.resize(
            (self.width, self.height),
            resample=resample_mode,
        )

        image_dto = context.services.images.create(
            image=resize_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(
                image_name=image_dto.image_name,
                image_origin=image_dto.image_origin,
            ),
            width=image_dto.width,
            height=image_dto.height,
        )


class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
    """Scales an image by a factor"""

    # fmt: off
    type: Literal["img_scale"] = "img_scale"

    # Inputs
    image: Union[ImageField, None] = Field(default=None, description="The image to scale")
    scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
    resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
    # fmt: on

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Scale the input image by ``scale_factor`` along both axes.

        Each target dimension is the source dimension multiplied by
        ``scale_factor`` and truncated to an integer, but never smaller than
        one pixel. The scaled image is stored as a new image.

        Returns:
            An ``ImageOutput`` referencing the newly created image, with the
            width and height reported for the stored image.
        """
        # NOTE(review): `image` defaults to None, so this attribute access
        # fails if the node runs without an image connected - confirm that
        # the graph guarantees the input is set.
        image = context.services.images.get_pil_image(
            self.image.image_origin, self.image.image_name
        )

        resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]

        # `scale_factor` only has to be > 0, so a very small factor would
        # truncate a dimension to 0; clamp to 1 px so the resize target is
        # always a non-empty image.
        width = max(1, int(image.width * self.scale_factor))
        height = max(1, int(image.height * self.scale_factor))

        resize_image = image.resize(
            (width, height),
            resample=resample_mode,
        )

        image_dto = context.services.images.create(
            image=resize_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(
                image_name=image_dto.image_name,
                image_origin=image_dto.image_origin,
            ),
            width=image_dto.width,
            height=image_dto.height,
        )


class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
    """Linear interpolation of all pixels of an image"""