mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Reduce peak VRAM utilization of SpandrelImageToImageInvocation.
This commit is contained in:
parent
d868d5d584
commit
d0d2955992
@ -1,4 +1,6 @@
|
|||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
@ -48,29 +50,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor):
|
|
||||||
"""A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we
|
|
||||||
split the overlap in half. No 'blending' is applied.
|
|
||||||
"""
|
|
||||||
# Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
|
|
||||||
# iterate over tiles left-to-right, top-to-bottom.
|
|
||||||
tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True))
|
|
||||||
tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left)
|
|
||||||
tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top)
|
|
||||||
|
|
||||||
for tile, tile_tensor in tiles_and_tensors:
|
|
||||||
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge
|
|
||||||
# artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems
|
|
||||||
# unnecessary, but we may find a need in the future.
|
|
||||||
top_overlap = tile.overlap.top // 2
|
|
||||||
left_overlap = tile.overlap.left // 2
|
|
||||||
out_tensor[
|
|
||||||
:,
|
|
||||||
:,
|
|
||||||
tile.coords.top + top_overlap : tile.coords.bottom,
|
|
||||||
tile.coords.left + left_overlap : tile.coords.right,
|
|
||||||
] = tile_tensor[:, :, top_overlap:, left_overlap:]
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
|
||||||
@ -97,6 +76,11 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
|
||||||
|
# over tiles left-to-right, top-to-bottom.
|
||||||
|
tiles = sorted(tiles, key=lambda x: x.coords.left)
|
||||||
|
tiles = sorted(tiles, key=lambda x: x.coords.top)
|
||||||
|
|
||||||
# Prepare input image for inference.
|
# Prepare input image for inference.
|
||||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||||
|
|
||||||
@ -104,8 +88,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||||
|
|
||||||
# Run the model on each tile.
|
# Run the model on each tile.
|
||||||
output_tiles: list[torch.Tensor] = []
|
|
||||||
scale: int = 1
|
|
||||||
with spandrel_model_info as spandrel_model:
|
with spandrel_model_info as spandrel_model:
|
||||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||||
|
|
||||||
@ -113,27 +95,45 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
scale = spandrel_model.scale
|
scale = spandrel_model.scale
|
||||||
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
|
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
|
||||||
|
|
||||||
|
# Prepare the output tensor.
|
||||||
|
_, channels, height, width = image_tensor.shape
|
||||||
|
output_tensor = torch.zeros(
|
||||||
|
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
|
||||||
|
)
|
||||||
|
|
||||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||||
|
|
||||||
for tile in tqdm(tiles, desc="Upscaling Tiles"):
|
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
|
||||||
output_tile = spandrel_model.run(
|
# Extract the current tile from the input tensor.
|
||||||
image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
|
input_tile = image_tensor[
|
||||||
)
|
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
|
||||||
output_tiles.append(output_tile)
|
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||||
|
|
||||||
# TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue:
|
# Run the model on the tile.
|
||||||
# - Keep the input tensor on the CPU.
|
output_tile = spandrel_model.run(input_tile)
|
||||||
# - Move each tile to the GPU as it is processed.
|
|
||||||
# - Move output tensors back to the CPU as they are produced, and merge them into the output tensor.
|
|
||||||
|
|
||||||
# Merge the tiles to an output tensor.
|
# Convert the output tile into the output tensor's format.
|
||||||
batch_size, channels, height, width = image_tensor.shape
|
# (N, C, H, W) -> (C, H, W)
|
||||||
output_tensor = torch.zeros(
|
output_tile = output_tile.squeeze(0)
|
||||||
(batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device
|
# (C, H, W) -> (H, W, C)
|
||||||
)
|
output_tile = output_tile.permute(1, 2, 0)
|
||||||
self._merge_tiles(scaled_tiles, output_tiles, output_tensor)
|
output_tile = output_tile.clamp(0, 1)
|
||||||
|
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
|
||||||
|
|
||||||
|
# Merge the output tile into the output tensor.
|
||||||
|
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
|
||||||
|
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
|
||||||
|
# it seems unnecessary, but we may find a need in the future.
|
||||||
|
top_overlap = scaled_tile.overlap.top // 2
|
||||||
|
left_overlap = scaled_tile.overlap.left // 2
|
||||||
|
output_tensor[
|
||||||
|
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
|
||||||
|
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
|
||||||
|
:,
|
||||||
|
] = output_tile[top_overlap:, left_overlap:, :]
|
||||||
|
|
||||||
# Convert the output tensor to a PIL image.
|
# Convert the output tensor to a PIL image.
|
||||||
pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor)
|
np_image = output_tensor.detach().numpy().astype(np.uint8)
|
||||||
|
pil_image = Image.fromarray(np_image)
|
||||||
image_dto = context.images.save(image=pil_image)
|
image_dto = context.images.save(image=pil_image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
Loading…
Reference in New Issue
Block a user