mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for Spandrel Image-to-Image models (e.g. ESRGAN, Real-ESRGAN, Swin-IR, DAT, etc.) (#6556)
## Summary - Add support for all [spandrel](https://github.com/chaiNNer-org/spandrel) image-to-image models - this is a collection of many popular super-resolution models (e.g. ESRGAN, Real-ESRGAN, SwinIR, DAT, etc.) Examples of supported models: - DAT: https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM - SwinIR: https://github.com/JingyunLiang/SwinIR/releases - Any ESRGAN / Real-ESRGAN model ## Related Issues Closes #6394 ## QA Instructions - [x] Test that unsupported models still fail the probe (i.e. no false positive spandrel models) - [x] Test adding a few non-spandrel model types - [x] Test adding a handful of spandrel model types: ESRGAN, Real-ESRGAN, SwinIR, DAT - [x] Verify model size estimation for the model cache - [x] Test using the spandrel models in a practical image upscaling workflow ## Merge Plan - [x] Get approval from @brandonrising and @maryhipp before merging - this PR has commercial implications. - [x] Merge #6571 and change the target branch to `main` ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [x] _Tests added / updated (if applicable)_ - [x] _Documentation added / updated (if applicable)_
This commit is contained in:
commit
7ad32dcad2
@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
T2IAdapterModel = "T2IAdapterModelField"
|
||||
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
@ -134,6 +135,7 @@ class FieldDescriptions:
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
spandrel_image_to_image_model = "Image-to-Image model"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
|
49
invokeai/app/invocations/spandrel_image_to_image.py
Normal file
49
invokeai/app/invocations/spandrel_image_to_image.py
Normal file
@ -0,0 +1,49 @@
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
InputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
|
||||
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
|
||||
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
image_to_image_model: ModelIdentifierField = InputField(
|
||||
title="Image-to-Image Model",
|
||||
description=FieldDescriptions.spandrel_image_to_image_model,
|
||||
ui_type=UIType.SpandrelImageToImageModel,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Load the model.
|
||||
spandrel_model_info = context.models.load(self.image_to_image_model)
|
||||
|
||||
with spandrel_model_info as spandrel_model:
|
||||
assert isinstance(spandrel_model, SpandrelImageToImageModel)
|
||||
|
||||
# Prepare input image for inference.
|
||||
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
|
||||
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
|
||||
|
||||
# Run inference.
|
||||
image_tensor = spandrel_model.run(image_tensor)
|
||||
|
||||
# Convert the output tensor to a PIL image.
|
||||
pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
@ -67,6 +67,7 @@ class ModelType(str, Enum):
|
||||
IPAdapter = "ip_adapter"
|
||||
CLIPVision = "clip_vision"
|
||||
T2IAdapter = "t2i_adapter"
|
||||
SpandrelImageToImage = "spandrel_image_to_image"
|
||||
|
||||
|
||||
class SubModelType(str, Enum):
|
||||
@ -371,6 +372,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
|
||||
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
|
||||
|
||||
|
||||
class SpandrelImageToImageConfig(ModelConfigBase):
|
||||
"""Model config for Spandrel Image to Image models."""
|
||||
|
||||
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
|
||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||
|
||||
@staticmethod
|
||||
def get_tag() -> Tag:
|
||||
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
|
||||
|
||||
|
||||
def get_model_discriminator_value(v: Any) -> str:
|
||||
"""
|
||||
Computes the discriminator value for a model config.
|
||||
@ -407,6 +419,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
|
||||
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
|
||||
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
|
||||
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||
],
|
||||
Discriminator(get_model_discriminator_value),
|
||||
|
@ -0,0 +1,45 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
|
||||
)
|
||||
class SpandrelImageToImageModelLoader(ModelLoader):
|
||||
"""Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if submodel_type is not None:
|
||||
raise ValueError("Unexpected submodel requested for Spandrel model.")
|
||||
|
||||
model_path = Path(config.path)
|
||||
model = SpandrelImageToImageModel.load_from_file(model_path)
|
||||
|
||||
torch_dtype = self._torch_dtype
|
||||
if not model.supports_dtype(torch_dtype):
|
||||
self._logger.warning(
|
||||
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
|
||||
"model. Falling back to 'float32'."
|
||||
)
|
||||
torch_dtype = torch.float32
|
||||
model.to(dtype=torch_dtype)
|
||||
|
||||
return model
|
@ -15,6 +15,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
elif isinstance(model, CLIPTokenizer):
|
||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||
return 0
|
||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
|
||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
|
||||
return model.calc_size()
|
||||
else:
|
||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||
|
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import safetensors.torch
|
||||
import spandrel
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import (
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
CkptType = Dict[str | int, Any]
|
||||
@ -220,24 +222,46 @@ class ModelProbe(object):
|
||||
ckpt = ckpt.get("state_dict", ckpt)
|
||||
|
||||
for key in [str(k) for k in ckpt.keys()]:
|
||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
|
||||
return ModelType.Main
|
||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
|
||||
return ModelType.VAE
|
||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||
elif key.startswith(("lora_te_", "lora_unet_")):
|
||||
return ModelType.LoRA
|
||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
|
||||
return ModelType.LoRA
|
||||
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
|
||||
elif key.startswith(("controlnet", "control_model", "input_blocks")):
|
||||
return ModelType.ControlNet
|
||||
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
|
||||
elif key.startswith(("image_proj.", "ip_adapter.")):
|
||||
return ModelType.IPAdapter
|
||||
elif key in {"emb_params", "string_to_param"}:
|
||||
return ModelType.TextualInversion
|
||||
else:
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
# diffusers-ti
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
|
||||
try:
|
||||
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||
# explored to avoid this:
|
||||
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
|
||||
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
|
||||
# supported on meta tensors.
|
||||
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
|
||||
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||
# maintain it, and the risk of false positive detections is higher.
|
||||
SpandrelImageToImageModel.load_from_file(model_path)
|
||||
return ModelType.SpandrelImageToImage
|
||||
except spandrel.UnsupportedModelError:
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
if "No such file or directory" in str(e):
|
||||
# This error is expected if the model_path does not exist (which is the case in some unit tests).
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
|
||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
#######################################################
|
||||
@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class SpandrelImageToImageFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
|
||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
|
||||
|
||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||
@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
|
||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
|
||||
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
|
||||
|
||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||
|
@ -1,15 +1,3 @@
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
@ -17,7 +5,17 @@ import torch
|
||||
|
||||
|
||||
class RawModel(ABC):
|
||||
"""Abstract base class for 'Raw' model wrappers."""
|
||||
"""Base class for 'Raw' models.
|
||||
|
||||
The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
|
||||
and is used for type checking of calls to the model patcher. Its main purpose
|
||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||
from lora.py.
|
||||
|
||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||
that adds additional methods and attributes.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||
|
134
invokeai/backend/spandrel_image_to_image_model.py
Normal file
134
invokeai/backend/spandrel_image_to_image_model.py
Normal file
@ -0,0 +1,134 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from spandrel import ImageModelDescriptor, ModelLoader
|
||||
|
||||
from invokeai.backend.raw_model import RawModel
|
||||
|
||||
|
||||
class SpandrelImageToImageModel(RawModel):
|
||||
"""A wrapper for a Spandrel Image-to-Image model.
|
||||
|
||||
The main reason for having a wrapper class is to integrate with the type handling of RawModel.
|
||||
"""
|
||||
|
||||
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
||||
self._spandrel_model = spandrel_model
|
||||
|
||||
@staticmethod
|
||||
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
|
||||
"""Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
|
||||
|
||||
Args:
|
||||
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||
"""
|
||||
image_np = np.array(image)
|
||||
# (H, W, C) -> (C, H, W)
|
||||
image_np = np.transpose(image_np, (2, 0, 1))
|
||||
image_np = image_np / 255
|
||||
image_tensor = torch.from_numpy(image_np).float()
|
||||
# (C, H, W) -> (N, C, H, W)
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
return image_tensor
|
||||
|
||||
@staticmethod
|
||||
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
|
||||
"""Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||
|
||||
Returns:
|
||||
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
|
||||
"""
|
||||
# (N, C, H, W) -> (C, H, W)
|
||||
tensor = tensor.squeeze(0)
|
||||
# (C, H, W) -> (H, W, C)
|
||||
tensor = tensor.permute(1, 2, 0)
|
||||
tensor = tensor.clamp(0, 1)
|
||||
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
|
||||
image = Image.fromarray(tensor)
|
||||
return image
|
||||
|
||||
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the image-to-image model.
|
||||
|
||||
Args:
|
||||
image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
|
||||
"""
|
||||
return self._spandrel_model(image_tensor)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str | Path):
|
||||
model = ModelLoader().load_from_file(file_path)
|
||||
if not isinstance(model, ImageModelDescriptor):
|
||||
raise ValueError(
|
||||
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||
"('ImageModelDescriptor')."
|
||||
)
|
||||
|
||||
return cls(spandrel_model=model)
|
||||
|
||||
@classmethod
|
||||
def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||
model = ModelLoader().load_from_state_dict(state_dict)
|
||||
if not isinstance(model, ImageModelDescriptor):
|
||||
raise ValueError(
|
||||
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||
"('ImageModelDescriptor')."
|
||||
)
|
||||
|
||||
return cls(spandrel_model=model)
|
||||
|
||||
def supports_dtype(self, dtype: torch.dtype) -> bool:
|
||||
"""Check if the model supports the given dtype."""
|
||||
if dtype == torch.float16:
|
||||
return self._spandrel_model.supports_half
|
||||
elif dtype == torch.bfloat16:
|
||||
return self._spandrel_model.supports_bfloat16
|
||||
elif dtype == torch.float32:
|
||||
# All models support float32.
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unexpected dtype '{dtype}'.")
|
||||
|
||||
def get_model_type_name(self) -> str:
|
||||
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
|
||||
consistent over time.
|
||||
"""
|
||||
return str(type(self._spandrel_model.model))
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
"""Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
|
||||
Note: The non_blocking parameter is currently ignored."""
|
||||
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
|
||||
# model directly if we want to apply this optimization.
|
||||
self._spandrel_model.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""The device of the underlying model."""
|
||||
return self._spandrel_model.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""The dtype of the underlying model."""
|
||||
return self._spandrel_model.dtype
|
||||
|
||||
def calc_size(self) -> int:
|
||||
"""Get size of the model in memory in bytes."""
|
||||
# HACK(ryand): Fix this issue with circular imports.
|
||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||
|
||||
return calc_module_size(self._spandrel_model.model)
|
@ -11,6 +11,7 @@ import {
|
||||
useLoRAModels,
|
||||
useMainModels,
|
||||
useRefinerModels,
|
||||
useSpandrelImageToImageModels,
|
||||
useT2IAdapterModels,
|
||||
useVAEModels,
|
||||
} from 'services/api/hooks/modelsByType';
|
||||
@ -71,6 +72,13 @@ const ModelList = () => {
|
||||
[vaeModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] =
|
||||
useSpandrelImageToImageModels();
|
||||
const filteredSpandrelImageToImageModels = useMemo(
|
||||
() => modelsFilter(spandrelImageToImageModels, searchTerm, filteredModelType),
|
||||
[spandrelImageToImageModels, searchTerm, filteredModelType]
|
||||
);
|
||||
|
||||
const totalFilteredModels = useMemo(() => {
|
||||
return (
|
||||
filteredMainModels.length +
|
||||
@ -80,7 +88,8 @@ const ModelList = () => {
|
||||
filteredControlNetModels.length +
|
||||
filteredT2IAdapterModels.length +
|
||||
filteredIPAdapterModels.length +
|
||||
filteredVAEModels.length
|
||||
filteredVAEModels.length +
|
||||
filteredSpandrelImageToImageModels.length
|
||||
);
|
||||
}, [
|
||||
filteredControlNetModels.length,
|
||||
@ -91,6 +100,7 @@ const ModelList = () => {
|
||||
filteredRefinerModels.length,
|
||||
filteredT2IAdapterModels.length,
|
||||
filteredVAEModels.length,
|
||||
filteredSpandrelImageToImageModels.length,
|
||||
]);
|
||||
|
||||
return (
|
||||
@ -143,6 +153,17 @@ const ModelList = () => {
|
||||
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
|
||||
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
|
||||
)}
|
||||
{/* Spandrel Image to Image List */}
|
||||
{isLoadingSpandrelImageToImageModels && (
|
||||
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
|
||||
)}
|
||||
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
|
||||
<ModelListWrapper
|
||||
title="Image-to-Image"
|
||||
modelList={filteredSpandrelImageToImageModels}
|
||||
key="spandrel-image-to-image"
|
||||
/>
|
||||
)}
|
||||
{totalFilteredModels === 0 && (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Text>{t('modelManager.noMatchingModels')}</Text>
|
||||
|
@ -21,6 +21,7 @@ export const ModelTypeFilter = () => {
|
||||
t2i_adapter: t('common.t2iAdapter'),
|
||||
ip_adapter: t('common.ipAdapter'),
|
||||
clip_vision: 'Clip Vision',
|
||||
spandrel_image_to_image: 'Image-to-Image',
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
|
@ -32,6 +32,8 @@ import {
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
@ -125,6 +128,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (
|
||||
isSpandrelImageToImageModelFieldInputInstance(fieldInstance) &&
|
||||
isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)
|
||||
) {
|
||||
return (
|
||||
<SpandrelImageToImageModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
@ -0,0 +1,55 @@
|
||||
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
SpandrelImageToImageModelFieldInputInstance,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
|
||||
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const SpandrelImageToImageModelFieldInputComponent = (
|
||||
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
|
||||
|
||||
const _onChange = useCallback(
|
||||
(value: SpandrelImageToImageModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldSpandrelImageToImageModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const { options, value, onChange } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
selectedModel: field.value,
|
||||
isLoading,
|
||||
});
|
||||
|
||||
return (
|
||||
<Tooltip label={value?.description}>
|
||||
<FormControl className="nowheel nodrag" isInvalid={!value}>
|
||||
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
|
||||
</FormControl>
|
||||
</Tooltip>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SpandrelImageToImageModelFieldInputComponent);
|
@ -19,6 +19,7 @@ import type {
|
||||
ModelIdentifierFieldValue,
|
||||
SchedulerFieldValue,
|
||||
SDXLRefinerModelFieldValue,
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
@ -39,6 +40,7 @@ import {
|
||||
zModelIdentifierFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
@ -333,6 +335,12 @@ export const nodesSlice = createSlice({
|
||||
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
|
||||
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
|
||||
},
|
||||
fieldSpandrelImageToImageModelValueChanged: (
|
||||
state,
|
||||
action: FieldValueAction<SpandrelImageToImageModelFieldValue>
|
||||
) => {
|
||||
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
|
||||
},
|
||||
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
|
||||
fieldValueReducer(state, action, zEnumFieldValue);
|
||||
},
|
||||
@ -384,6 +392,7 @@ export const {
|
||||
fieldImageValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldSpandrelImageToImageModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldModelIdentifierValueChanged,
|
||||
|
@ -66,6 +66,7 @@ const zModelType = z.enum([
|
||||
'embedding',
|
||||
'onnx',
|
||||
'clip_vision',
|
||||
'spandrel_image_to_image',
|
||||
]);
|
||||
const zSubModelType = z.enum([
|
||||
'unet',
|
||||
|
@ -38,6 +38,7 @@ export const MODEL_TYPES = [
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
'T2IAdapterModelField',
|
||||
'SpandrelImageToImageModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
SpandrelImageToImageModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
|
@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('T2IAdapterModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SpandrelImageToImageModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
|
||||
zControlNetModelFieldType,
|
||||
zIPAdapterModelFieldType,
|
||||
zT2IAdapterModelFieldType,
|
||||
zSpandrelImageToImageModelFieldType,
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
]);
|
||||
@ -581,6 +586,33 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
|
||||
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SpandrelModelToModelField
|
||||
|
||||
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
|
||||
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSpandrelImageToImageModelFieldValue,
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSpandrelImageToImageModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSpandrelImageToImageModelFieldValue,
|
||||
});
|
||||
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSpandrelImageToImageModelFieldType,
|
||||
});
|
||||
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
|
||||
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
|
||||
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
|
||||
export const isSpandrelImageToImageModelFieldInputInstance = (
|
||||
val: unknown
|
||||
): val is SpandrelImageToImageModelFieldInputInstance =>
|
||||
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSpandrelImageToImageModelFieldInputTemplate = (
|
||||
val: unknown
|
||||
): val is SpandrelImageToImageModelFieldInputTemplate =>
|
||||
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SchedulerField
|
||||
|
||||
export const zSchedulerFieldValue = zSchedulerField.optional();
|
||||
@ -667,6 +699,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zControlNetModelFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
]);
|
||||
@ -694,6 +727,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zControlNetModelFieldInputInstance,
|
||||
zIPAdapterModelFieldInputInstance,
|
||||
zT2IAdapterModelFieldInputInstance,
|
||||
zSpandrelImageToImageModelFieldInputInstance,
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
]);
|
||||
@ -722,6 +756,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zControlNetModelFieldInputTemplate,
|
||||
zIPAdapterModelFieldInputTemplate,
|
||||
zT2IAdapterModelFieldInputTemplate,
|
||||
zSpandrelImageToImageModelFieldInputTemplate,
|
||||
zColorFieldInputTemplate,
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
@ -751,6 +786,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
zIPAdapterModelFieldOutputTemplate,
|
||||
zT2IAdapterModelFieldOutputTemplate,
|
||||
zSpandrelImageToImageModelFieldOutputTemplate,
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
]);
|
||||
|
@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
SpandrelImageToImageModelField: undefined,
|
||||
VAEModelField: undefined,
|
||||
ControlNetModelField: undefined,
|
||||
};
|
||||
|
@ -17,6 +17,7 @@ import type {
|
||||
SchedulerFieldInputTemplate,
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
|
||||
SpandrelImageToImageModelFieldInputTemplate
|
||||
> = ({ schemaObject, baseField, fieldType }) => {
|
||||
const template: SpandrelImageToImageModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
|
@ -35,6 +35,7 @@ const MODEL_FIELD_TYPES = [
|
||||
'ControlNetModelField',
|
||||
'IPAdapterModelField',
|
||||
'T2IAdapterModelField',
|
||||
'SpandrelImageToImageModelField',
|
||||
];
|
||||
|
||||
/**
|
||||
|
@ -11,6 +11,7 @@ import {
|
||||
isNonSDXLMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isSpandrelImageToImageModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isTIModelConfig,
|
||||
isVAEModelConfig,
|
||||
@ -39,6 +40,7 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
|
||||
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
|
||||
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
|
||||
export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
|
||||
export const useVAEModels = buildModelsHook(isVAEModelConfig);
|
||||
|
File diff suppressed because one or more lines are too long
@ -51,6 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
|
||||
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
|
||||
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
|
||||
export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
|
||||
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
|
||||
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
|
||||
type DiffusersModelConfig = S['MainDiffusersConfig'];
|
||||
type CheckpointModelConfig = S['MainCheckpointConfig'];
|
||||
@ -62,6 +63,7 @@ export type AnyModelConfig =
|
||||
| ControlNetModelConfig
|
||||
| IPAdapterModelConfig
|
||||
| T2IAdapterModelConfig
|
||||
| SpandrelImageToImageModelConfig
|
||||
| TextualInversionModelConfig
|
||||
| MainModelConfig
|
||||
| CLIPVisionDiffusersConfig;
|
||||
@ -86,6 +88,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
|
||||
return config.type === 't2i_adapter';
|
||||
};
|
||||
|
||||
export const isSpandrelImageToImageModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is SpandrelImageToImageModelConfig => {
|
||||
return config.type === 'spandrel_image_to_image';
|
||||
};
|
||||
|
||||
export const isControlAdapterModelConfig = (
|
||||
config: AnyModelConfig
|
||||
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
|
||||
|
@ -46,6 +46,7 @@ dependencies = [
|
||||
"opencv-python==4.9.0.80",
|
||||
"pytorch-lightning==2.1.3",
|
||||
"safetensors==0.4.3",
|
||||
"spandrel==0.3.4",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"torch==2.2.2",
|
||||
"torchmetrics==0.11.4",
|
||||
|
Loading…
Reference in New Issue
Block a user