From c1afe3570423da43077ec91c452a187bf055024a Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 27 Jun 2024 15:17:22 -0400
Subject: [PATCH] Add prototype invocation for running upscaling models with
 spandrel.

---
 invokeai/app/invocations/spandrel_upscale.py | 94 ++++++++++++++++++++
 pyproject.toml                               |  1 +
 2 files changed, 95 insertions(+)
 create mode 100644 invokeai/app/invocations/spandrel_upscale.py

diff --git a/invokeai/app/invocations/spandrel_upscale.py b/invokeai/app/invocations/spandrel_upscale.py
new file mode 100644
index 0000000000..babf399cd6
--- /dev/null
+++ b/invokeai/app/invocations/spandrel_upscale.py
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+from PIL import Image
+from spandrel import ImageModelDescriptor, ModelLoader
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.util.devices import TorchDevice
+
+
+def pil_to_tensor(image: Image.Image) -> torch.Tensor:
+    """Convert PIL Image to torch.Tensor.
+
+    Args:
+        image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
+
+    Returns:
+        torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+    """
+    image_np = np.array(image)
+    # (H, W, C) -> (C, H, W)
+    image_np = np.transpose(image_np, (2, 0, 1))
+    image_np = image_np / 255
+    image_tensor = torch.from_numpy(image_np).float()
+    # (C, H, W) -> (N, C, H, W)
+    image_tensor = image_tensor.unsqueeze(0)
+    return image_tensor
+
+
+def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
+    """Convert torch.Tensor to PIL Image.
+
+    Args:
+        tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+
+    Returns:
+        Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
+    """
+    # (N, C, H, W) -> (C, H, W)
+    tensor = tensor.squeeze(0)
+    # (C, H, W) -> (H, W, C)
+    tensor = tensor.permute(1, 2, 0)
+    tensor = tensor.clamp(0, 1)
+    tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
+    image = Image.fromarray(tensor)
+    return image
+
+
+@invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0")
+class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
+    """Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
+
+    image: ImageField = InputField(description="The input image")
+    # TODO(ryand): Figure out how to handle all the spandrel models so that you don't have to enter a string.
+    model_path: str = InputField(description="The path to the upscaling model to use.")
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        image = context.images.get_pil(self.image.image_name)
+
+        # Load the model.
+        # TODO(ryand): Integrate with the model manager.
+        model = ModelLoader().load_from_file(self.model_path)
+        if not isinstance(model, ImageModelDescriptor):
+            raise ValueError(
+                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
+                "('ImageModelDescriptor')."
+            )
+
+        # Select model device and dtype.
+        torch_dtype = TorchDevice.choose_torch_dtype()
+        torch_device = TorchDevice.choose_torch_device()
+        if (torch_dtype == torch.float16 and not model.supports_half) or (
+            torch_dtype == torch.bfloat16 and not model.supports_bfloat16
+        ):
+            context.logger.warning(
+                f"The configured dtype ('{torch_dtype}') is not supported by the {type(model.model)} model. Falling "
+                "back to 'float32'."
+            )
+            torch_dtype = torch.float32
+        model.to(device=torch_device, dtype=torch_dtype)
+
+        # Prepare input image for inference.
+        image_tensor = pil_to_tensor(image)
+        image_tensor = image_tensor.to(device=torch_device, dtype=torch_dtype)
+
+        # Run inference.
+        image_tensor = model(image_tensor)
+
+        # Convert the output tensor to a PIL image.
+        pil_image = tensor_to_pil(image_tensor)
+        image_dto = context.images.save(image=pil_image)
+        return ImageOutput.build(image_dto)
diff --git a/pyproject.toml b/pyproject.toml
index fcc0aff60c..fa716254de 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
   "opencv-python==4.9.0.80",
   "pytorch-lightning==2.1.3",
   "safetensors==0.4.3",
+  "spandrel==0.3.4",
   "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
   "torch==2.2.2",
   "torchmetrics==0.11.4",
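
Note (not part of the patch): the helpers added above can also be exercised outside of InvokeAI's invocation machinery, which is handy for trying a spandrel model before wiring it into a graph. The following is a minimal standalone sketch, not a definitive implementation; the model file "RealESRGAN_x4plus.pth", the input "input.png", and the plain CUDA check (in place of the node's TorchDevice helpers) are assumptions for illustration only.

import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader

from invokeai.app.invocations.spandrel_upscale import pil_to_tensor, tensor_to_pil

# Load a model from a local file (hypothetical path) and confirm it is an image-to-image model,
# mirroring the isinstance check in the invocation.
model = ModelLoader().load_from_file("RealESRGAN_x4plus.pth")
assert isinstance(model, ImageModelDescriptor)

# Pick a device and dtype, falling back to float32 when half precision is unsupported.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if (device.type == "cuda" and model.supports_half) else torch.float32
model.to(device=device, dtype=dtype)

# pil_to_tensor assumes a 3-channel (H, W, C) image, so convert to RGB first.
image = Image.open("input.png").convert("RGB")
with torch.no_grad():  # inference only; no gradients are needed
    output = model(pil_to_tensor(image).to(device=device, dtype=dtype))

tensor_to_pil(output).save("output.png")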