Move pil_to_tensor() and tensor_to_pil() utilities to the SpandrelImageToImageModel class.

This commit is contained in:
Ryan Dick
2024-07-02 10:11:25 -04:00
parent 1ab20f43c8
commit 6161aa73af
2 changed files with 47 additions and 43 deletions

View File

@ -1,6 +1,4 @@
import numpy as np
import torch import torch
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@ -17,44 +15,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL Image to a batched, channels-first float tensor.

    Args:
        image (Image.Image): A PIL Image whose array form has shape (H, W, C), or (H, W) for a
            single-channel image, with values in the range [0, 255].

    Returns:
        torch.Tensor: A float32 tensor with shape (1, C, H, W) and values in the range [0, 1].
            A single-channel (H, W) input yields C = 1.
    """
    image_np = np.array(image)
    if image_np.ndim == 2:
        # Single-channel images convert to a 2-D array. Add an explicit channel axis so the
        # transpose below does not fail with an axis-count mismatch.
        image_np = image_np[:, :, np.newaxis]
    # (H, W, C) -> (C, H, W)
    image_np = np.transpose(image_np, (2, 0, 1))
    # Rescale [0, 255] -> [0, 1].
    image_np = image_np / 255
    image_tensor = torch.from_numpy(image_np).float()
    # (C, H, W) -> (1, C, H, W)
    return image_tensor.unsqueeze(0)
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a batched, channels-first float tensor to a PIL Image.

    Args:
        tensor (torch.Tensor): A tensor with shape (1, C, H, W) and values nominally in the range
            [0, 1]. Out-of-range values are clamped. The batch dimension must be 1, because it is
            squeezed away before the channel permutation.

    Returns:
        Image.Image: A PIL Image built from a uint8 array with shape (H, W, C), or (H, W) when
            C = 1, and values in the range [0, 255].
    """
    # (1, C, H, W) -> (C, H, W)
    tensor = tensor.detach().squeeze(0)
    # (C, H, W) -> (H, W, C)
    tensor = tensor.permute(1, 2, 0)
    # Do the scaling in float32 so that half-precision inputs are not quantized a second time
    # when multiplied by 255.
    tensor = tensor.float().clamp(0, 1)
    # Round to the nearest integer instead of truncating: truncation biases every pixel downward
    # and turns e.g. 253.9999 into 253. Converting to uint8 in torch (rather than calling .numpy()
    # on the float tensor) also keeps bfloat16 inputs working, since NumPy has no bfloat16 dtype.
    image_np = (tensor * 255).round().to(torch.uint8).cpu().numpy()
    if image_np.shape[2] == 1:
        # Image.fromarray() expects a 2-D array for single-channel data, not (H, W, 1).
        image_np = image_np[:, :, 0]
    return Image.fromarray(image_np)
@invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0") @invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0")
class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard): class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel).""" """Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
@ -75,13 +35,13 @@ class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Prepare input image for inference. # Prepare input image for inference.
image_tensor = pil_to_tensor(image) image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run inference. # Run inference.
image_tensor = spandrel_model.run(image_tensor) image_tensor = spandrel_model.run(image_tensor)
# Convert the output tensor to a PIL image. # Convert the output tensor to a PIL image.
pil_image = tensor_to_pil(image_tensor) pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)

View File

@ -1,7 +1,9 @@
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
import numpy as np
import torch import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader from spandrel import ImageModelDescriptor, ModelLoader
from invokeai.backend.raw_model import RawModel from invokeai.backend.raw_model import RawModel
@ -16,8 +18,50 @@ class SpandrelImageToImageModel(RawModel):
def __init__(self, spandrel_model: ImageModelDescriptor[Any]): def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
self._spandrel_model = spandrel_model self._spandrel_model = spandrel_model
@staticmethod
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
    """Convert a PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().

    Args:
        image (Image.Image): A PIL Image whose array form has shape (H, W, C), or (H, W) for a
            single-channel image, with values in the range [0, 255].

    Returns:
        torch.Tensor: A float32 tensor with shape (1, C, H, W) and values in the range [0, 1].
            A single-channel (H, W) input yields C = 1.
    """
    image_np = np.array(image)
    if image_np.ndim == 2:
        # Single-channel images convert to a 2-D array. Add an explicit channel axis so the
        # transpose below does not fail with an axis-count mismatch.
        image_np = image_np[:, :, np.newaxis]
    # (H, W, C) -> (C, H, W)
    image_np = np.transpose(image_np, (2, 0, 1))
    # Rescale [0, 255] -> [0, 1].
    image_np = image_np / 255
    image_tensor = torch.from_numpy(image_np).float()
    # (C, H, W) -> (1, C, H, W)
    return image_tensor.unsqueeze(0)
@staticmethod
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
    """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.

    Args:
        tensor (torch.Tensor): A tensor with shape (1, C, H, W) and values nominally in the range
            [0, 1]. Out-of-range values are clamped. The batch dimension must be 1, because it is
            squeezed away before the channel permutation.

    Returns:
        Image.Image: A PIL Image built from a uint8 array with shape (H, W, C), or (H, W) when
            C = 1, and values in the range [0, 255].
    """
    # (1, C, H, W) -> (C, H, W)
    tensor = tensor.detach().squeeze(0)
    # (C, H, W) -> (H, W, C)
    tensor = tensor.permute(1, 2, 0)
    # Do the scaling in float32 so that half-precision inputs are not quantized a second time
    # when multiplied by 255.
    tensor = tensor.float().clamp(0, 1)
    # Round to the nearest integer instead of truncating: truncation biases every pixel downward
    # and turns e.g. 253.9999 into 253. Converting to uint8 in torch (rather than calling .numpy()
    # on the float tensor) also keeps bfloat16 inputs working, since NumPy has no bfloat16 dtype.
    image_np = (tensor * 255).round().to(torch.uint8).cpu().numpy()
    if image_np.shape[2] == 1:
        # Image.fromarray() expects a 2-D array for single-channel data, not (H, W, 1).
        image_np = image_np[:, :, 0]
    return Image.fromarray(image_np)
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
    """Run the image-to-image model.

    Args:
        image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].

    Returns:
        torch.Tensor: Whatever the wrapped spandrel model returns when called on `image_tensor`.
    """
    return self._spandrel_model(image_tensor)
@classmethod @classmethod