mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): split out spandrel node upscale logic into utils
This commit is contained in:
parent
13f3560e55
commit
a2ef5d56ee
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Callable
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -35,7 +37,8 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
|
@classmethod
|
||||||
|
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
|
||||||
return Tile(
|
return Tile(
|
||||||
coords=TBLR(
|
coords=TBLR(
|
||||||
top=tile.coords.top * scale,
|
top=tile.coords.top * scale,
|
||||||
@ -51,20 +54,22 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@classmethod
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def upscale_image(
|
||||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
cls,
|
||||||
# revisit this.
|
image: Image.Image,
|
||||||
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
tile_size: int,
|
||||||
|
spandrel_model: SpandrelImageToImageModel,
|
||||||
|
is_canceled: Callable[[], bool],
|
||||||
|
) -> Image.Image:
|
||||||
# Compute the image tiles.
|
# Compute the image tiles.
|
||||||
if self.tile_size > 0:
|
if tile_size > 0:
|
||||||
min_overlap = 20
|
min_overlap = 20
|
||||||
tiles = calc_tiles_min_overlap(
|
tiles = calc_tiles_min_overlap(
|
||||||
image_height=image.height,
|
image_height=image.height,
|
||||||
image_width=image.width,
|
image_width=image.width,
|
||||||
tile_height=self.tile_size,
|
tile_height=tile_size,
|
||||||
tile_width=self.tile_size,
|
tile_width=tile_size,
|
||||||
min_overlap=min_overlap,
|
min_overlap=min_overlap,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -85,16 +90,9 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
# Prepare input image for inference.
|
# Prepare input image for inference.
|
||||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||||
|
|
||||||
# Load the model.
|
|
||||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
|
||||||
|
|
||||||
# Run the model on each tile.
|
|
||||||
with spandrel_model_info as spandrel_model:
|
|
||||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
|
||||||
|
|
||||||
# Scale the tiles for re-assembling the final image.
|
# Scale the tiles for re-assembling the final image.
|
||||||
scale = spandrel_model.scale
|
scale = spandrel_model.scale
|
||||||
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
|
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
|
||||||
|
|
||||||
# Prepare the output tensor.
|
# Prepare the output tensor.
|
||||||
_, channels, height, width = image_tensor.shape
|
_, channels, height, width = image_tensor.shape
|
||||||
@ -106,7 +104,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||||
# Exit early if the invocation has been canceled.
|
# Exit early if the invocation has been canceled.
|
||||||
if context.util.is_canceled():
|
if is_canceled():
|
||||||
raise CanceledException
|
raise CanceledException
|
||||||
|
|
||||||
# Extract the current tile from the input tensor.
|
# Extract the current tile from the input tensor.
|
||||||
@ -140,5 +138,23 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
# Convert the output tensor to a PIL image.
|
# Convert the output tensor to a PIL image.
|
||||||
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||||
pil_image = Image.fromarray(np_image)
|
pil_image = Image.fromarray(np_image)
|
||||||
|
|
||||||
|
return pil_image
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
|
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||||
|
# revisit this.
|
||||||
|
image = context.images.get_pil(self.image.image_name, mode="RGB")
|
||||||
|
|
||||||
|
# Load the model.
|
||||||
|
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||||
|
|
||||||
|
# Run the model on each tile.
|
||||||
|
with spandrel_model_info as spandrel_model:
|
||||||
|
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||||
|
|
||||||
|
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
|
||||||
|
|
||||||
image_dto = context.images.save(image=pil_image)
|
image_dto = context.images.save(image=pil_image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
Loading…
Reference in New Issue
Block a user