feat(nodes): add scale and fit_to_multiple_of_8 to spandrel node

This commit is contained in:
psychedelicious 2024-07-23 08:40:51 +10:00
parent a2ef5d56ee
commit ac6adc392a

View File

@ -23,7 +23,7 @@ from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.2.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
@ -36,6 +36,16 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
)
scale: float = InputField(
default=1.0,
gt=0.0,
le=16.0,
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
)
fit_to_multiple_of_8: bool = InputField(
default=False,
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@classmethod
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
@ -102,6 +112,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if is_canceled():
@ -150,11 +161,62 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# Run the model on each tile.
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
# to be considered an upscale model.
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
iterations += 1
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
if iterations >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
break
else:
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
# to be the same as the processed image size.
# The output size is now the size of the processed image.
target_width = pil_image.width
target_height = pil_image.height
# Warn the user if they requested a scale greater than 1.
if self.scale > 1:
context.logger.warning(
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
)
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
# in the final resize
if self.fit_to_multiple_of_8:
target_width = int(target_width // 8 * 8)
target_height = int(target_height // 8 * 8)
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)