Use a ModelIdentifierField to identify the spandrel model in the UpscaleSpandrelInvocation.

This commit is contained in:
Ryan Dick 2024-06-28 15:30:35 -04:00
parent 2a1514272f
commit 95079dc7d4
2 changed files with 36 additions and 29 deletions

View File

@ -1,13 +1,20 @@
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
def pil_to_tensor(image: Image.Image) -> torch.Tensor: def pil_to_tensor(image: Image.Image) -> torch.Tensor:
@ -53,40 +60,26 @@ class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel).""" """Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
image: ImageField = InputField(description="The input image") image: ImageField = InputField(description="The input image")
# TODO(ryand): Figure out how to handle all the spandrel models so that you don't have to enter a string. spandrel_image_to_image_model: ModelIdentifierField = InputField(
model_path: str = InputField(description="The path to the upscaling model to use.") description=FieldDescriptions.spandrel_image_to_image_model, ui_type=UIType.LoRAModel
)
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name) image = context.images.get_pil(self.image.image_name)
# Load the model. # Load the model.
# TODO(ryand): Integrate with the model manager. spandrel_model_info = context.models.load(self.spandrel_image_to_image_model)
model = ModelLoader().load_from_file(self.model_path)
if not isinstance(model, ImageModelDescriptor):
raise ValueError(
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
"('ImageModelDescriptor')."
)
# Select model device and dtype. with spandrel_model_info as spandrel_model:
torch_dtype = TorchDevice.choose_torch_dtype() assert isinstance(spandrel_model, SpandrelImageToImageModel)
torch_device = TorchDevice.choose_torch_device()
if (torch_dtype == torch.float16 and not model.supports_half) or (
torch_dtype == torch.bfloat16 and not model.supports_bfloat16
):
context.logger.warning(
f"The configured dtype ('{torch_dtype}') is not supported by the {type(model.model)} model. Falling "
"back to 'float32'."
)
torch_dtype = torch.float32
model.to(device=torch_device, dtype=torch_dtype)
# Prepare input image for inference. # Prepare input image for inference.
image_tensor = pil_to_tensor(image) image_tensor = pil_to_tensor(image)
image_tensor = image_tensor.to(device=torch_device, dtype=torch_dtype) image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run inference. # Run inference.
image_tensor = model(image_tensor) image_tensor = spandrel_model.run(image_tensor)
# Convert the output tensor to a PIL image. # Convert the output tensor to a PIL image.
pil_image = tensor_to_pil(image_tensor) pil_image = tensor_to_pil(image_tensor)

View File

@ -16,6 +16,10 @@ class SpandrelImageToImageModel(RawModel):
def __init__(self, spandrel_model: ImageModelDescriptor[Any]): def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
self._spandrel_model = spandrel_model self._spandrel_model = spandrel_model
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
    """Apply the wrapped spandrel model to ``image_tensor`` and return its output.

    The tensor is forwarded unchanged: this method performs no device or
    dtype conversion, so any such preparation is left to the caller.
    """
    wrapped_model = self._spandrel_model
    output_tensor = wrapped_model(image_tensor)
    return output_tensor
@classmethod @classmethod
def load_from_file(cls, file_path: str | Path): def load_from_file(cls, file_path: str | Path):
model = ModelLoader().load_from_file(file_path) model = ModelLoader().load_from_file(file_path)
@ -67,3 +71,13 @@ class SpandrelImageToImageModel(RawModel):
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
# directly if we want to apply this optimization. # directly if we want to apply this optimization.
self._spandrel_model.to(device=device, dtype=dtype) self._spandrel_model.to(device=device, dtype=dtype)
@property
def device(self) -> torch.device:
    """Report the ``device`` attribute of the wrapped spandrel model."""
    wrapped_model = self._spandrel_model
    return wrapped_model.device
@property
def dtype(self) -> torch.dtype:
    """Report the ``dtype`` attribute of the wrapped spandrel model."""
    wrapped_model = self._spandrel_model
    return wrapped_model.dtype