feat(nodes): split out spandrel node upscale logic into utils

This commit is contained in:
psychedelicious 2024-07-23 05:56:45 +10:00
parent 13f3560e55
commit a2ef5d56ee

View File

@ -1,3 +1,5 @@
from typing import Callable
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
@ -35,7 +37,8 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling." default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
) )
def _scale_tile(self, tile: Tile, scale: int) -> Tile: @classmethod
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
return Tile( return Tile(
coords=TBLR( coords=TBLR(
top=tile.coords.top * scale, top=tile.coords.top * scale,
@ -51,20 +54,22 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
), ),
) )
@torch.inference_mode() @classmethod
def invoke(self, context: InvocationContext) -> ImageOutput: def upscale_image(
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to cls,
# revisit this. image: Image.Image,
image = context.images.get_pil(self.image.image_name, mode="RGB") tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
) -> Image.Image:
# Compute the image tiles. # Compute the image tiles.
if self.tile_size > 0: if tile_size > 0:
min_overlap = 20 min_overlap = 20
tiles = calc_tiles_min_overlap( tiles = calc_tiles_min_overlap(
image_height=image.height, image_height=image.height,
image_width=image.width, image_width=image.width,
tile_height=self.tile_size, tile_height=tile_size,
tile_width=self.tile_size, tile_width=tile_size,
min_overlap=min_overlap, min_overlap=min_overlap,
) )
else: else:
@ -85,6 +90,63 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# Prepare input image for inference. # Prepare input image for inference.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
# Scale the tiles for re-assembling the final image.
scale = spandrel_model.scale
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]
# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
return pil_image
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model. # Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model) spandrel_model_info = context.models.load(self.image_to_image_model)
@ -92,53 +154,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
with spandrel_model_info as spandrel_model: with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Scale the tiles for re-assembling the final image. pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
scale = spandrel_model.scale
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if context.util.is_canceled():
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]
# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)