from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader

from invokeai.backend.raw_model import RawModel


class SpandrelImageToImageModel(RawModel):
    """A wrapper for a Spandrel Image-to-Image model.

    The main reason for having a wrapper class is to integrate with the type handling of RawModel.
    """

    def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
        """Wrap an already-loaded spandrel image model descriptor.

        Args:
            spandrel_model (ImageModelDescriptor[Any]): The spandrel model descriptor to wrap.
        """
        self._spandrel_model = spandrel_model

    @staticmethod
    def pil_to_tensor(image: Image.Image) -> torch.Tensor:
        """Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().

        Args:
            image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].

        Returns:
            torch.Tensor: A float32 torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
        """
        image_np = np.array(image)
        # (H, W, C) -> (C, H, W)
        image_np = np.transpose(image_np, (2, 0, 1))
        # Scale [0, 255] -> [0, 1].
        image_np = image_np / 255
        image_tensor = torch.from_numpy(image_np).float()
        # (C, H, W) -> (N, C, H, W)
        image_tensor = image_tensor.unsqueeze(0)
        return image_tensor

    @staticmethod
    def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
        """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.

        Values outside [0, 1] are clamped. The tensor may be on any device and may have any floating point dtype
        (including float16 / bfloat16).

        Args:
            tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1]. The batch
                dimension N must be 1.

        Returns:
            Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
        """
        # (N, C, H, W) -> (C, H, W)
        tensor = tensor.detach().squeeze(0)
        # (C, H, W) -> (H, W, C)
        tensor = tensor.permute(1, 2, 0)
        # Move to CPU float32 before touching numpy: numpy has no bfloat16, so calling .numpy() directly on a
        # bfloat16 model output raises a TypeError. float32 also keeps the scaling below accurate for half inputs.
        tensor = tensor.cpu().float()
        tensor = tensor.clamp(0, 1)
        # Round to the nearest integer before the uint8 cast. A bare astype(np.uint8) truncates, which biases every
        # pixel downwards (e.g. 254.99 -> 254).
        image_np = (tensor * 255).round().numpy().astype(np.uint8)
        return Image.fromarray(image_np)

    def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """Run the image-to-image model.

        Args:
            image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].

        Returns:
            torch.Tensor: The output of the underlying spandrel model.
        """
        return self._spandrel_model(image_tensor)

    @classmethod
    def _from_loaded_model(cls, model: Any) -> "SpandrelImageToImageModel":
        """Validate a model returned by spandrel's ModelLoader and wrap it.

        Raises:
            ValueError: If the loaded model is not an ImageModelDescriptor.
        """
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )
        return cls(spandrel_model=model)

    @classmethod
    def load_from_file(cls, file_path: str | Path) -> "SpandrelImageToImageModel":
        """Load a spandrel model from a file and wrap it.

        Args:
            file_path (str | Path): Path to the model file.

        Raises:
            ValueError: If the loaded model is not an image-to-image model ('ImageModelDescriptor').
        """
        model = ModelLoader().load_from_file(file_path)
        return cls._from_loaded_model(model)

    @classmethod
    def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "SpandrelImageToImageModel":
        """Load a spandrel model from an in-memory state dict and wrap it.

        Args:
            state_dict (dict[str, torch.Tensor]): The model state dict.

        Raises:
            ValueError: If the loaded model is not an image-to-image model ('ImageModelDescriptor').
        """
        model = ModelLoader().load_from_state_dict(state_dict)
        return cls._from_loaded_model(model)

    def supports_dtype(self, dtype: torch.dtype) -> bool:
        """Check if the model supports the given dtype.

        Raises:
            ValueError: If dtype is not one of float16, bfloat16 or float32.
        """
        if dtype == torch.float16:
            return self._spandrel_model.supports_half
        elif dtype == torch.bfloat16:
            return self._spandrel_model.supports_bfloat16
        elif dtype == torch.float32:
            # All models support float32.
            return True
        else:
            raise ValueError(f"Unexpected dtype '{dtype}'.")

    def get_model_type_name(self) -> str:
        """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
        consistent over time.
        """
        return str(type(self._spandrel_model.model))

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        """Move the underlying model to the given device and/or dtype.

        Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
        Note: The non_blocking parameter is currently ignored.
        """
        # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
        # directly if we want to apply this optimization.
        self._spandrel_model.to(device=device, dtype=dtype)

    @property
    def device(self) -> torch.device:
        """The device of the underlying model."""
        return self._spandrel_model.device

    @property
    def dtype(self) -> torch.dtype:
        """The dtype of the underlying model."""
        return self._spandrel_model.dtype