mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
from pathlib import Path
|
|
from typing import Any, Optional
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from spandrel import ImageModelDescriptor, ModelLoader
|
|
|
|
from invokeai.backend.raw_model import RawModel
|
|
|
|
|
|
class SpandrelImageToImageModel(RawModel):
    """A wrapper for a Spandrel Image-to-Image model.

    The main reason for having a wrapper class is to integrate with the type handling of RawModel.
    """

    def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
        self._spandrel_model = spandrel_model

    @staticmethod
    def pil_to_tensor(image: Image.Image) -> torch.Tensor:
        """Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().

        Args:
            image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255]. Single-channel
                images (e.g. mode "L"), whose array form is (H, W), are treated as having C=1.

        Returns:
            torch.Tensor: A float32 torch.Tensor with shape (N, C, H, W), N=1, and values in the range [0, 1].
        """
        # Convert straight to float32 instead of materializing an intermediate float64 array.
        image_np = np.asarray(image, dtype=np.float32) / 255.0
        if image_np.ndim == 2:
            # Single-channel images come out of numpy as (H, W); add the channel axis -> (H, W, 1).
            image_np = image_np[:, :, np.newaxis]
        image_tensor = torch.from_numpy(image_np)
        # (H, W, C) -> (C, H, W) -> (N, C, H, W)
        return image_tensor.permute(2, 0, 1).unsqueeze(0)

    @staticmethod
    def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
        """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.

        Args:
            tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W), N=1, and values in the range [0, 1].
                Values outside [0, 1] are clamped. Any floating dtype and device are accepted.

        Returns:
            Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255]. A single-channel
                tensor (C=1) yields a single-channel (mode "L") image.
        """
        # (N, C, H, W) -> (C, H, W)
        tensor = tensor.squeeze(0)
        # (C, H, W) -> (H, W, C)
        tensor = tensor.permute(1, 2, 0)
        # Quantize in float32 inside torch: numpy has no bfloat16 dtype, so calling .numpy() on a bfloat16
        # model output would raise. Converting to uint8 before leaving torch avoids that for every dtype.
        tensor = tensor.detach().float().clamp(0, 1)
        # Round rather than truncate so that e.g. 254.9999 maps to 255; this keeps
        # tensor_to_pil(pil_to_tensor(img)) an exact round trip.
        image_np = (tensor * 255).round().to(torch.uint8).cpu().numpy()
        if image_np.shape[-1] == 1:
            # PIL does not reliably accept (H, W, 1) arrays; drop the channel axis so it builds a mode "L" image.
            image_np = image_np[:, :, 0]
        return Image.fromarray(image_np)

    def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run the image-to-image model.

        Args:
            image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
                Presumably it must already match this model's device/dtype (see `to(...)`) - TODO confirm.

        Returns:
            torch.Tensor: The model output, in the (N, C, H, W), [0, 1] format consumed by tensor_to_pil().
                The spatial size may differ from the input's (e.g. for upscaling models).
        """
        return self._spandrel_model(image_tensor)

    @classmethod
    def load_from_file(cls, file_path: str | Path) -> "SpandrelImageToImageModel":
        """Load a spandrel model from a file and wrap it.

        Raises:
            ValueError: If the loaded spandrel model is not an image-to-image model (ImageModelDescriptor).
        """
        return cls._from_spandrel_model(ModelLoader().load_from_file(file_path))

    @classmethod
    def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "SpandrelImageToImageModel":
        """Load a spandrel model from an in-memory state dict and wrap it.

        Raises:
            ValueError: If the loaded spandrel model is not an image-to-image model (ImageModelDescriptor).
        """
        return cls._from_spandrel_model(ModelLoader().load_from_state_dict(state_dict))

    @classmethod
    def _from_spandrel_model(cls, model: Any) -> "SpandrelImageToImageModel":
        """Wrap a freshly loaded spandrel model, rejecting anything that is not an ImageModelDescriptor."""
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )

        return cls(spandrel_model=model)

    def supports_dtype(self, dtype: torch.dtype) -> bool:
        """Check if the model supports the given dtype.

        Raises:
            ValueError: If `dtype` is not one of float16, bfloat16 or float32.
        """
        if dtype == torch.float16:
            return self._spandrel_model.supports_half
        elif dtype == torch.bfloat16:
            return self._spandrel_model.supports_bfloat16
        elif dtype == torch.float32:
            # All models support float32.
            return True
        else:
            raise ValueError(f"Unexpected dtype '{dtype}'.")

    def get_model_type_name(self) -> str:
        """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
        consistent over time.
        """
        return str(type(self._spandrel_model.model))

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        """Move the underlying model to the given device and/or dtype.

        Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.

        Note: The non_blocking parameter is currently ignored.
        """
        # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
        # model directly if we want to apply this optimization.
        self._spandrel_model.to(device=device, dtype=dtype)

    @property
    def device(self) -> torch.device:
        """The device of the underlying model."""
        return self._spandrel_model.device

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the underlying model."""
        return self._spandrel_model.dtype
|