import torch
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile


@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""

    image: ImageField = InputField(description="The input image")
    image_to_image_model: ModelIdentifierField = InputField(
        title="Image-to-Image Model",
        description=FieldDescriptions.spandrel_image_to_image_model,
        ui_type=UIType.SpandrelImageToImageModel,
    )
    tile_size: int = InputField(
        default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
    )

    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
        """Scale a tile's coordinates and overlaps by the model's scale factor."""
        return Tile(
            coords=TBLR(
                top=tile.coords.top * scale,
                bottom=tile.coords.bottom * scale,
                left=tile.coords.left * scale,
                right=tile.coords.right * scale,
            ),
            overlap=TBLR(
                top=tile.overlap.top * scale,
                bottom=tile.overlap.bottom * scale,
                left=tile.overlap.left * scale,
                right=tile.overlap.right * scale,
            ),
        )
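
    # Worked example (illustrative only): with a 4x model, a tile whose coords are
    # TBLR(top=0, bottom=512, left=256, right=768) maps to
    # TBLR(top=0, bottom=2048, left=1024, right=3072) in output-image space, and
    # its overlaps scale the same way.
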
    def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor):
        """A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we
        split the overlap in half. No 'blending' is applied.
        """
        # Sort the tiles and tensors first by left x coordinate, then by top y coordinate. During tile processing, we
        # want to iterate over tiles left-to-right, top-to-bottom.
        tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True))
        tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left)
        tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top)

        for tile, tile_tensor in tiles_and_tensors:
            # We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge
            # artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it
            # seems unnecessary, but we may find a need in the future.
            top_overlap = tile.overlap.top // 2
            left_overlap = tile.overlap.left // 2
            out_tensor[
                :,
                :,
                tile.coords.top + top_overlap : tile.coords.bottom,
                tile.coords.left + left_overlap : tile.coords.right,
            ] = tile_tensor[:, :, top_overlap:, left_overlap:]
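
    # Worked example (illustrative only): if a tile's overlap.top is 20, the merge
    # above drops that tile's first 10 output rows (top_overlap == 10), so the seam
    # between vertically adjacent tiles falls halfway through their shared overlap.
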
    @torch.inference_mode()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want
        # to revisit this.
        image = context.images.get_pil(self.image.image_name, mode="RGB")

        # Compute the image tiles.
        if self.tile_size > 0:
            min_overlap = 20
            tiles = calc_tiles_min_overlap(
                image_height=image.height,
                image_width=image.width,
                tile_height=self.tile_size,
                tile_width=self.tile_size,
                min_overlap=min_overlap,
            )
        else:
            # No tiling. Generate a single tile that covers the entire image.
            min_overlap = 0
            tiles = [
                Tile(
                    coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
                    overlap=TBLR(top=0, bottom=0, left=0, right=0),
                )
            ]
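
        # Note (illustrative): in the no-tiling branch the single tile has zero
        # overlap on every edge, so _merge_tiles() copies the model output verbatim.
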
        # Prepare input image for inference.
        image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)

        # Load the model.
        spandrel_model_info = context.models.load(self.image_to_image_model)
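
        # Note: spandrel_model_info is entered as a context manager below; in
        # InvokeAI's model manager this keeps the loaded model available (and on
        # the execution device) for the duration of the block.
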
        # Run the model on each tile.
        output_tiles: list[torch.Tensor] = []
        scale: int = 1
        with spandrel_model_info as spandrel_model:
            assert isinstance(spandrel_model, SpandrelImageToImageModel)

            # Scale the tiles for re-assembling the final image.
            scale = spandrel_model.scale
            scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]

            image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

            for tile in tqdm(tiles, desc="Upscaling Tiles"):
                output_tile = spandrel_model.run(
                    image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
                )
                output_tiles.append(output_tile)
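
        # At this point, each entry in output_tiles should have shape
        # (batch, channels, tile_height * scale, tile_width * scale), lining up with
        # the scaled_tiles computed above.
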
        # TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue:
        # - Keep the input tensor on the CPU.
        # - Move each tile to the GPU as it is processed.
        # - Move output tensors back to the CPU as they are produced, and merge them into the output tensor.

        # Merge the tiles to an output tensor.
        batch_size, channels, height, width = image_tensor.shape
        output_tensor = torch.zeros(
            (batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device
        )
        self._merge_tiles(scaled_tiles, output_tiles, output_tensor)

        # Convert the output tensor to a PIL image.
        pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor)
        image_dto = context.images.save(image=pil_image)
        return ImageOutput.build(image_dto)
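

if __name__ == "__main__":
    # Minimal sketch (not part of the invocation, and assumes invokeai is installed
    # so the imports above resolve): print the tile layout the tiling branch above
    # would compute for a hypothetical 1024x768 image with the default 512px tiles
    # and 20px minimum overlap.
    for example_tile in calc_tiles_min_overlap(
        image_height=768, image_width=1024, tile_height=512, tile_width=512, min_overlap=20
    ):
        print(example_tile.coords, example_tile.overlap)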