mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use a ModelIdentifierField to identify the spandrel model in the UpscaleSpandrelInvocation.
This commit is contained in:
parent
2a1514272f
commit
95079dc7d4
@ -1,13 +1,20 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from spandrel import ImageModelDescriptor, ModelLoader
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
|
||||
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
|
||||
@ -53,40 +60,26 @@ class UpscaleSpandrelInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Upscales an image using any upscaler supported by spandrel (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
# TODO(ryand): Figure out how to handle all the spandrel models so that you don't have to enter a string.
|
||||
model_path: str = InputField(description="The path to the upscaling model to use.")
|
||||
spandrel_image_to_image_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.spandrel_image_to_image_model, ui_type=UIType.LoRAModel
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Load the model.
|
||||
# TODO(ryand): Integrate with the model manager.
|
||||
model = ModelLoader().load_from_file(self.model_path)
|
||||
if not isinstance(model, ImageModelDescriptor):
|
||||
raise ValueError(
|
||||
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||
"('ImageModelDescriptor')."
|
||||
)
|
||||
spandrel_model_info = context.models.load(self.spandrel_image_to_image_model)
|
||||
|
||||
# Select model device and dtype.
|
||||
torch_dtype = TorchDevice.choose_torch_dtype()
|
||||
torch_device = TorchDevice.choose_torch_device()
|
||||
if (torch_dtype == torch.float16 and not model.supports_half) or (
|
||||
torch_dtype == torch.bfloat16 and not model.supports_bfloat16
|
||||
):
|
||||
context.logger.warning(
|
||||
f"The configured dtype ('{torch_dtype}') is not supported by the {type(model.model)} model. Falling "
|
||||
"back to 'float32'."
|
||||
)
|
||||
torch_dtype = torch.float32
|
||||
model.to(device=torch_device, dtype=torch_dtype)
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Prepare input image for inference.
|
||||
image_tensor = pil_to_tensor(image)
|
||||
image_tensor = image_tensor.to(device=torch_device, dtype=torch_dtype)
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run inference.
|
||||
image_tensor = model(image_tensor)
|
||||
image_tensor = spandrel_model.run(image_tensor)
|
||||
|
||||
# Convert the output tensor to a PIL image.
|
||||
pil_image = tensor_to_pil(image_tensor)
|
||||
|
@ -16,6 +16,10 @@ class SpandrelImageToImageModel(RawModel):
|
||||
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
||||
self._spandrel_model = spandrel_model
|
||||
|
||||
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the image-to-image model."""
|
||||
return self._spandrel_model(image_tensor)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str | Path):
|
||||
model = ModelLoader().load_from_file(file_path)
|
||||
@ -67,3 +71,13 @@ class SpandrelImageToImageModel(RawModel):
|
||||
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
|
||||
# directly if we want to apply this optimization.
|
||||
self._spandrel_model.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""The device of the underlying model."""
|
||||
return self._spandrel_model.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""The dtype of the underlying model."""
|
||||
return self._spandrel_model.dtype
|
||||
|
Loading…
Reference in New Issue
Block a user