mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add resize image and scale image nodes
This commit is contained in:
parent
bce33ea62e
commit
bbb4e8f5ef
@ -416,6 +416,115 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
PIL_RESAMPLING_MODES = Literal[
|
||||||
|
"nearest",
|
||||||
|
"box",
|
||||||
|
"bilinear",
|
||||||
|
"hamming",
|
||||||
|
"bicubic",
|
||||||
|
"lanczos",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
PIL_RESAMPLING_MAP = {
|
||||||
|
"nearest": Image.Resampling.NEAREST,
|
||||||
|
"box": Image.Resampling.BOX,
|
||||||
|
"bilinear": Image.Resampling.BILINEAR,
|
||||||
|
"hamming": Image.Resampling.HAMMING,
|
||||||
|
"bicubic": Image.Resampling.BICUBIC,
|
||||||
|
"lanczos": Image.Resampling.LANCZOS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Resizes an image to specific dimensions"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_resize"] = "img_resize"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(default=None, description="The image to resize")
|
||||||
|
width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
|
||||||
|
height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
|
||||||
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_origin, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
|
||||||
|
resize_image = image.resize(
|
||||||
|
(self.width, self.height),
|
||||||
|
resample=resample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=resize_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_origin=image_dto.image_origin,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
|
"""Scales an image by a factor"""
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
type: Literal["img_scale"] = "img_scale"
|
||||||
|
|
||||||
|
# Inputs
|
||||||
|
image: Union[ImageField, None] = Field(default=None, description="The image to scale")
|
||||||
|
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
||||||
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
image = context.services.images.get_pil_image(
|
||||||
|
self.image.image_origin, self.image.image_name
|
||||||
|
)
|
||||||
|
|
||||||
|
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||||
|
width = int(image.width * self.scale_factor)
|
||||||
|
height = int(image.height * self.scale_factor)
|
||||||
|
|
||||||
|
resize_image = image.resize(
|
||||||
|
(width, height),
|
||||||
|
resample=resample_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_dto = context.services.images.create(
|
||||||
|
image=resize_image,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
image_category=ImageCategory.GENERAL,
|
||||||
|
node_id=self.id,
|
||||||
|
session_id=context.graph_execution_state_id,
|
||||||
|
is_intermediate=self.is_intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ImageOutput(
|
||||||
|
image=ImageField(
|
||||||
|
image_name=image_dto.image_name,
|
||||||
|
image_origin=image_dto.image_origin,
|
||||||
|
),
|
||||||
|
width=image_dto.width,
|
||||||
|
height=image_dto.height,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Linear interpolation of all pixels of an image"""
|
"""Linear interpolation of all pixels of an image"""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user