mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move pil_to_tensor() and tensor_to_pil() utilities to the SpandrelImageToImage class.
This commit is contained in:
@ -1,6 +1,4 @@
|
|||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||||
from invokeai.app.invocations.fields import (
|
from invokeai.app.invocations.fields import (
|
||||||
@ -17,44 +15,6 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
|||||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||||
|
|
||||||
|
|
||||||
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
|
|
||||||
"""Convert PIL Image to torch.Tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
|
||||||
"""
|
|
||||||
image_np = np.array(image)
|
|
||||||
# (H, W, C) -> (C, H, W)
|
|
||||||
image_np = np.transpose(image_np, (2, 0, 1))
|
|
||||||
image_np = image_np / 255
|
|
||||||
image_tensor = torch.from_numpy(image_np).float()
|
|
||||||
# (C, H, W) -> (N, C, H, W)
|
|
||||||
image_tensor = image_tensor.unsqueeze(0)
|
|
||||||
return image_tensor
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
|
|
||||||
"""Convert torch.Tensor to PIL Image.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
|
||||||
"""
|
|
||||||
# (N, C, H, W) -> (C, H, W)
|
|
||||||
tensor = tensor.squeeze(0)
|
|
||||||
# (C, H, W) -> (H, W, C)
|
|
||||||
tensor = tensor.permute(1, 2, 0)
|
|
||||||
tensor = tensor.clamp(0, 1)
|
|
||||||
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
|
|
||||||
image = Image.fromarray(tensor)
|
|
||||||
return image
|
|
||||||
|
|
||||||
|
|
||||||
@invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0")
|
@invocation("upscale_spandrel", title="Upscale (spandrel)", tags=["upscale"], category="upscale", version="1.0.0")
|
||||||
class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
|
class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||||
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
|
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
|
||||||
@ -75,13 +35,13 @@ class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||||
|
|
||||||
# Prepare input image for inference.
|
# Prepare input image for inference.
|
||||||
image_tensor = pil_to_tensor(image)
|
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||||
|
|
||||||
# Run inference.
|
# Run inference.
|
||||||
image_tensor = spandrel_model.run(image_tensor)
|
image_tensor = spandrel_model.run(image_tensor)
|
||||||
|
|
||||||
# Convert the output tensor to a PIL image.
|
# Convert the output tensor to a PIL image.
|
||||||
pil_image = tensor_to_pil(image_tensor)
|
pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
|
||||||
image_dto = context.images.save(image=pil_image)
|
image_dto = context.images.save(image=pil_image)
|
||||||
return ImageOutput.build(image_dto)
|
return ImageOutput.build(image_dto)
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from spandrel import ImageModelDescriptor, ModelLoader
|
from spandrel import ImageModelDescriptor, ModelLoader
|
||||||
|
|
||||||
from invokeai.backend.raw_model import RawModel
|
from invokeai.backend.raw_model import RawModel
|
||||||
@ -16,8 +18,50 @@ class SpandrelImageToImageModel(RawModel):
|
|||||||
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
||||||
self._spandrel_model = spandrel_model
|
self._spandrel_model = spandrel_model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
|
||||||
|
"""Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||||
|
"""
|
||||||
|
image_np = np.array(image)
|
||||||
|
# (H, W, C) -> (C, H, W)
|
||||||
|
image_np = np.transpose(image_np, (2, 0, 1))
|
||||||
|
image_np = image_np / 255
|
||||||
|
image_tensor = torch.from_numpy(image_np).float()
|
||||||
|
# (C, H, W) -> (N, C, H, W)
|
||||||
|
image_tensor = image_tensor.unsqueeze(0)
|
||||||
|
return image_tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
|
||||||
|
"""Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
||||||
|
"""
|
||||||
|
# (N, C, H, W) -> (C, H, W)
|
||||||
|
tensor = tensor.squeeze(0)
|
||||||
|
# (C, H, W) -> (H, W, C)
|
||||||
|
tensor = tensor.permute(1, 2, 0)
|
||||||
|
tensor = tensor.clamp(0, 1)
|
||||||
|
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
|
||||||
|
image = Image.fromarray(tensor)
|
||||||
|
return image
|
||||||
|
|
||||||
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||||
"""Run the image-to-image model."""
|
"""Run the image-to-image model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||||
|
"""
|
||||||
return self._spandrel_model(image_tensor)
|
return self._spandrel_model(image_tensor)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
Reference in New Issue
Block a user