Add prototype invocation for running upscaling models with spandrel.

This commit is contained in:
Ryan Dick 2024-06-27 15:17:22 -04:00
parent e4813f800a
commit c1afe35704
2 changed files with 95 additions and 0 deletions

View File

@ -0,0 +1,94 @@
import numpy as np
import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
"""Convert PIL Image to torch.Tensor.
Args:
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
Returns:
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
"""
image_np = np.array(image)
# (H, W, C) -> (C, H, W)
image_np = np.transpose(image_np, (2, 0, 1))
image_np = image_np / 255
image_tensor = torch.from_numpy(image_np).float()
# (C, H, W) -> (N, C, H, W)
image_tensor = image_tensor.unsqueeze(0)
return image_tensor
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
"""Convert torch.Tensor to PIL Image.
Args:
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
Returns:
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
"""
# (N, C, H, W) -> (C, H, W)
tensor = tensor.squeeze(0)
# (C, H, W) -> (H, W, C)
tensor = tensor.permute(1, 2, 0)
tensor = tensor.clamp(0, 1)
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
image = Image.fromarray(tensor)
return image
@invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0")
class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
image: ImageField = InputField(description="The input image")
# TODO(ryand): Figure out how to handle all the spandrel models so that you don't have to enter a string.
model_path: str = InputField(description="The path to the upscaling model to use.")
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
# Load the model.
# TODO(ryand): Integrate with the model manager.
model = ModelLoader().load_from_file(self.model_path)
if not isinstance(model, ImageModelDescriptor):
raise ValueError(
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
"('ImageModelDescriptor')."
)
# Select model device and dtype.
torch_dtype = TorchDevice.choose_torch_dtype()
torch_device = TorchDevice.choose_torch_device()
if (torch_dtype == torch.float16 and not model.supports_half) or (
torch_dtype == torch.bfloat16 and not model.supports_bfloat16
):
context.logger.warning(
f"The configured dtype ('{torch_dtype}') is not supported by the {type(model.model)} model. Falling "
"back to 'float32'."
)
torch_dtype = torch.float32
model.to(device=torch_device, dtype=torch_dtype)
# Prepare input image for inference.
image_tensor = pil_to_tensor(image)
image_tensor = image_tensor.to(device=torch_device, dtype=torch_dtype)
# Run inference.
image_tensor = model(image_tensor)
# Convert the output tensor to a PIL image.
pil_image = tensor_to_pil(image_tensor)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

View File

@ -46,6 +46,7 @@ dependencies = [
"opencv-python==4.9.0.80", "opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3", "pytorch-lightning==2.1.3",
"safetensors==0.4.3", "safetensors==0.4.3",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2", "torch==2.2.2",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",