diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 8374a58959..f9a483f84c 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
+ SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
# region Misc Field Types
@@ -134,6 +135,7 @@ class FieldDescriptions:
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
+ spandrel_image_to_image_model = "Image-to-Image model"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"
diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py
new file mode 100644
index 0000000000..76cf31480c
--- /dev/null
+++ b/invokeai/app/invocations/spandrel_image_to_image.py
@@ -0,0 +1,49 @@
+import torch
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import (
+ FieldDescriptions,
+ ImageField,
+ InputField,
+ UIType,
+ WithBoard,
+ WithMetadata,
+)
+from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+
+
+@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
+class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
+ """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
+
+ image: ImageField = InputField(description="The input image")
+ image_to_image_model: ModelIdentifierField = InputField(
+ title="Image-to-Image Model",
+ description=FieldDescriptions.spandrel_image_to_image_model,
+ ui_type=UIType.SpandrelImageToImageModel,
+ )
+
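+    # torch.inference_mode() disables autograd tracking for the whole invocation, reducing memory overhead
+    # during inference.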
+ @torch.inference_mode()
+ def invoke(self, context: InvocationContext) -> ImageOutput:
+ image = context.images.get_pil(self.image.image_name)
+
+ # Load the model.
+ spandrel_model_info = context.models.load(self.image_to_image_model)
+
+ with spandrel_model_info as spandrel_model:
+ assert isinstance(spandrel_model, SpandrelImageToImageModel)
+
+ # Prepare input image for inference.
+ image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
+ image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+
+ # Run inference.
+ image_tensor = spandrel_model.run(image_tensor)
+
+ # Convert the output tensor to a PIL image.
+ pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
+ image_dto = context.images.save(image=pil_image)
+ return ImageOutput.build(image_dto)
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index dbcd259368..f6cc5929c8 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -67,6 +67,7 @@ class ModelType(str, Enum):
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
+ SpandrelImageToImage = "spandrel_image_to_image"
class SubModelType(str, Enum):
@@ -371,6 +372,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
+class SpandrelImageToImageConfig(ModelConfigBase):
+ """Model config for Spandrel Image to Image models."""
+
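+    # Spandrel image-to-image models are single-file checkpoints and are not tied to a specific base model
+    # (the checkpoint probe reports BaseModelType.Any for them).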
+ type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
+ format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+
+ @staticmethod
+ def get_tag() -> Tag:
+ return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
+
+
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -407,6 +419,7 @@ AnyModelConfig = Annotated[
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
+ Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py
new file mode 100644
index 0000000000..7a57c5cf59
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py
@@ -0,0 +1,45 @@
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from invokeai.backend.model_manager.config import (
+ AnyModel,
+ AnyModelConfig,
+ BaseModelType,
+ ModelFormat,
+ ModelType,
+ SubModelType,
+)
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
+
+
+@ModelLoaderRegistry.register(
+ base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
+)
+class SpandrelImageToImageModelLoader(ModelLoader):
+ """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
+
+ def _load_model(
+ self,
+ config: AnyModelConfig,
+ submodel_type: Optional[SubModelType] = None,
+ ) -> AnyModel:
+ if submodel_type is not None:
+ raise ValueError("Unexpected submodel requested for Spandrel model.")
+
+ model_path = Path(config.path)
+ model = SpandrelImageToImageModel.load_from_file(model_path)
+
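+        # Check dtype support before moving the model: not all Spandrel models support float16/bfloat16.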
+ torch_dtype = self._torch_dtype
+ if not model.supports_dtype(torch_dtype):
+ self._logger.warning(
+ f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
+ "model. Falling back to 'float32'."
+ )
+ torch_dtype = torch.float32
+ model.to(dtype=torch_dtype)
+
+ return model
diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py
index 64fbd29a1f..f070a42965 100644
--- a/invokeai/backend/model_manager/load/model_util.py
+++ b/invokeai/backend/model_manager/load/model_util.py
@@ -15,6 +15,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
@@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
- elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+ elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
return model.calc_size()
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index f6fb2d24bc..1929b3f4fd 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch
+import spandrel
import torch
from picklescan.scanner import scan_file_path
@@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
+from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
CkptType = Dict[str | int, Any]
@@ -220,24 +222,46 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]:
- if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
+ if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
return ModelType.Main
- elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
+ elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE
- elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
+ elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA
- elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
+ elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
return ModelType.LoRA
- elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
+ elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet
- elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+ elif key.startswith(("image_proj.", "ip_adapter.")):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
- else:
- # diffusers-ti
- if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
- return ModelType.TextualInversion
+
+ # diffusers-ti
+ if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
+ return ModelType.TextualInversion
+
+ # Check if the model can be loaded as a SpandrelImageToImageModel.
+ # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
+ try:
+ # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+ # explored to avoid this:
+ # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+ # device. Unfortunately, some Spandrel models perform operations during initialization that are not
+ # supported on meta tensors.
+ # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+ # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+ # maintain it, and the risk of false positive detections is higher.
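+            # Note: the loaded model is discarded immediately; only the fact that it loads successfully is used here.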
+ SpandrelImageToImageModel.load_from_file(model_path)
+ return ModelType.SpandrelImageToImage
+ except spandrel.UnsupportedModelError:
+ pass
+ except RuntimeError as e:
+ if "No such file or directory" in str(e):
+ # This error is expected if the model_path does not exist (which is the case in some unit tests).
+ pass
+ else:
+ raise e
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
+class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
+ def get_base_type(self) -> BaseModelType:
+ return BaseModelType.Any
+
+
########################################################
# classes for probing folders
#######################################################
@@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
+class SpandrelImageToImageFolderProbe(FolderProbeBase):
+ def get_base_type(self) -> BaseModelType:
+ raise NotImplementedError()
+
+
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py
index 931804c985..23502b20cb 100644
--- a/invokeai/backend/raw_model.py
+++ b/invokeai/backend/raw_model.py
@@ -1,15 +1,3 @@
-"""Base class for 'Raw' models.
-
-The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
-and is used for type checking of calls to the model patcher. Its main purpose
-is to avoid a circular import issues when lora.py tries to import BaseModelType
-from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
-from lora.py.
-
-The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
-that adds additional methods and attributes.
-"""
-
from abc import ABC, abstractmethod
from typing import Optional
@@ -17,7 +5,17 @@ import torch
class RawModel(ABC):
- """Abstract base class for 'Raw' model wrappers."""
+ """Base class for 'Raw' models.
+
+ The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
+ and is used for type checking of calls to the model patcher. Its main purpose
+    is to avoid circular import issues when lora.py tries to import BaseModelType
+ from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
+ from lora.py.
+
+ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
+ that adds additional methods and attributes.
+ """
@abstractmethod
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
diff --git a/invokeai/backend/spandrel_image_to_image_model.py b/invokeai/backend/spandrel_image_to_image_model.py
new file mode 100644
index 0000000000..adb78d0d71
--- /dev/null
+++ b/invokeai/backend/spandrel_image_to_image_model.py
@@ -0,0 +1,134 @@
+from pathlib import Path
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from PIL import Image
+from spandrel import ImageModelDescriptor, ModelLoader
+
+from invokeai.backend.raw_model import RawModel
+
+
+class SpandrelImageToImageModel(RawModel):
+ """A wrapper for a Spandrel Image-to-Image model.
+
+ The main reason for having a wrapper class is to integrate with the type handling of RawModel.
+ """
+
+ def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
+ self._spandrel_model = spandrel_model
+
+ @staticmethod
+ def pil_to_tensor(image: Image.Image) -> torch.Tensor:
+ """Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
+
+ Args:
+ image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
+
+ Returns:
+ torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+ """
+ image_np = np.array(image)
+ # (H, W, C) -> (C, H, W)
+ image_np = np.transpose(image_np, (2, 0, 1))
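+        # Scale uint8 pixel values from [0, 255] down to floats in [0, 1].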
+ image_np = image_np / 255
+ image_tensor = torch.from_numpy(image_np).float()
+ # (C, H, W) -> (N, C, H, W)
+ image_tensor = image_tensor.unsqueeze(0)
+ return image_tensor
+
+ @staticmethod
+ def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
+ """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
+
+ Args:
+ tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+
+ Returns:
+ Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
+ """
+ # (N, C, H, W) -> (C, H, W)
+ tensor = tensor.squeeze(0)
+ # (C, H, W) -> (H, W, C)
+ tensor = tensor.permute(1, 2, 0)
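+        # Clamp to [0, 1]; some models can produce values slightly outside this range.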
+ tensor = tensor.clamp(0, 1)
+ tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
+ image = Image.fromarray(tensor)
+ return image
+
+ def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
+ """Run the image-to-image model.
+
+ Args:
+ image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
+ """
+ return self._spandrel_model(image_tensor)
+
+ @classmethod
+ def load_from_file(cls, file_path: str | Path):
+ model = ModelLoader().load_from_file(file_path)
+ if not isinstance(model, ImageModelDescriptor):
+ raise ValueError(
+ f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
+ "('ImageModelDescriptor')."
+ )
+
+ return cls(spandrel_model=model)
+
+ @classmethod
+ def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
+ model = ModelLoader().load_from_state_dict(state_dict)
+ if not isinstance(model, ImageModelDescriptor):
+ raise ValueError(
+ f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
+ "('ImageModelDescriptor')."
+ )
+
+ return cls(spandrel_model=model)
+
+ def supports_dtype(self, dtype: torch.dtype) -> bool:
+ """Check if the model supports the given dtype."""
+ if dtype == torch.float16:
+ return self._spandrel_model.supports_half
+ elif dtype == torch.bfloat16:
+ return self._spandrel_model.supports_bfloat16
+ elif dtype == torch.float32:
+ # All models support float32.
+ return True
+ else:
+ raise ValueError(f"Unexpected dtype '{dtype}'.")
+
+ def get_model_type_name(self) -> str:
+ """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
+ consistent over time.
+ """
+ return str(type(self._spandrel_model.model))
+
+ def to(
+ self,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ non_blocking: bool = False,
+ ) -> None:
+ """Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
+ Note: The non_blocking parameter is currently ignored."""
+ # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
+ # model directly if we want to apply this optimization.
+ self._spandrel_model.to(device=device, dtype=dtype)
+
+ @property
+ def device(self) -> torch.device:
+ """The device of the underlying model."""
+ return self._spandrel_model.device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """The dtype of the underlying model."""
+ return self._spandrel_model.dtype
+
+ def calc_size(self) -> int:
+ """Get size of the model in memory in bytes."""
+ # HACK(ryand): Fix this issue with circular imports.
+ from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+ return calc_module_size(self._spandrel_model.model)
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
index 67e65dbfb6..b82917221e 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx
@@ -11,6 +11,7 @@ import {
useLoRAModels,
useMainModels,
useRefinerModels,
+ useSpandrelImageToImageModels,
useT2IAdapterModels,
useVAEModels,
} from 'services/api/hooks/modelsByType';
@@ -71,6 +72,13 @@ const ModelList = () => {
[vaeModels, searchTerm, filteredModelType]
);
+ const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] =
+ useSpandrelImageToImageModels();
+ const filteredSpandrelImageToImageModels = useMemo(
+ () => modelsFilter(spandrelImageToImageModels, searchTerm, filteredModelType),
+ [spandrelImageToImageModels, searchTerm, filteredModelType]
+ );
+
const totalFilteredModels = useMemo(() => {
return (
filteredMainModels.length +
@@ -80,7 +88,8 @@ const ModelList = () => {
filteredControlNetModels.length +
filteredT2IAdapterModels.length +
filteredIPAdapterModels.length +
- filteredVAEModels.length
+ filteredVAEModels.length +
+ filteredSpandrelImageToImageModels.length
);
}, [
filteredControlNetModels.length,
@@ -91,6 +100,7 @@ const ModelList = () => {
filteredRefinerModels.length,
filteredT2IAdapterModels.length,
filteredVAEModels.length,
+ filteredSpandrelImageToImageModels.length,
]);
return (
@@ -143,6 +153,17 @@ const ModelList = () => {
        {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
          <ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
        )}
+ {/* Spandrel Image to Image List */}
+        {isLoadingSpandrelImageToImageModels && (
+          <FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
+        )}
+        {!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
+          <ModelListWrapper
+            title="Image-to-Image"
+            modelList={filteredSpandrelImageToImageModels}
+            key="spandrel-image-to-image"
+          />
+        )}
{totalFilteredModels === 0 && (
{t('modelManager.noMatchingModels')}
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx
index 76802b36e7..1a2444870b 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx
@@ -21,6 +21,7 @@ export const ModelTypeFilter = () => {
t2i_adapter: t('common.t2iAdapter'),
ip_adapter: t('common.ipAdapter'),
clip_vision: 'Clip Vision',
+ spandrel_image_to_image: 'Image-to-Image',
}),
[t]
);
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
index 99937ceec4..d863def973 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx
@@ -32,6 +32,8 @@ import {
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate,
+ isSpandrelImageToImageModelFieldInputInstance,
+ isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
@@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
+import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@@ -125,6 +128,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
    return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
+
+ if (
+ isSpandrelImageToImageModelFieldInputInstance(fieldInstance) &&
+ isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)
+ ) {
+ return (
+      <SpandrelImageToImageModelFieldInputComponent
+        nodeId={nodeId}
+        field={fieldInstance}
+        fieldTemplate={fieldTemplate}
+      />
+ );
+ }
+
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
    return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx
new file mode 100644
index 0000000000..ccd4eaa797
--- /dev/null
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx
@@ -0,0 +1,55 @@
+import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
+import { useAppDispatch } from 'app/store/storeHooks';
+import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
+import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
+import type {
+ SpandrelImageToImageModelFieldInputInstance,
+ SpandrelImageToImageModelFieldInputTemplate,
+} from 'features/nodes/types/field';
+import { memo, useCallback } from 'react';
+import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
+import type { SpandrelImageToImageModelConfig } from 'services/api/types';
+
+import type { FieldComponentProps } from './types';
+
+const SpandrelImageToImageModelFieldInputComponent = (
+  props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
+) => {
+ const { nodeId, field } = props;
+ const dispatch = useAppDispatch();
+
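+  // useSpandrelImageToImageModels is expected to return only spandrel_image_to_image model configs,
+  // so the combobox offers just those models.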
+ const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
+
+ const _onChange = useCallback(
+ (value: SpandrelImageToImageModelConfig | null) => {
+ if (!value) {
+ return;
+ }
+ dispatch(
+ fieldSpandrelImageToImageModelValueChanged({
+ nodeId,
+ fieldName: field.name,
+ value,
+ })
+ );
+ },
+ [dispatch, field.name, nodeId]
+ );
+
+ const { options, value, onChange } = useGroupedModelCombobox({
+ modelConfigs,
+ onChange: _onChange,
+ selectedModel: field.value,
+ isLoading,
+ });
+
+ return (
+    <Tooltip label={value?.description}>
+      <FormControl className="nowheel nodrag" isDisabled={!options.length}>
+        <Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
+      </FormControl>
+    </Tooltip>
+ );
+};
+
+export default memo(SpandrelImageToImageModelFieldInputComponent);
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 5ebc5de147..f9214c1572 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -19,6 +19,7 @@ import type {
ModelIdentifierFieldValue,
SchedulerFieldValue,
SDXLRefinerModelFieldValue,
+ SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldValue,
T2IAdapterModelFieldValue,
@@ -39,6 +40,7 @@ import {
zModelIdentifierFieldValue,
zSchedulerFieldValue,
zSDXLRefinerModelFieldValue,
+ zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
@@ -333,6 +335,12 @@ export const nodesSlice = createSlice({
    fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
},
+ fieldSpandrelImageToImageModelValueChanged: (
+ state,
+      action: FieldValueAction<SpandrelImageToImageModelFieldValue>
+ ) => {
+ fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
+ },
    fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue);
},
@@ -384,6 +392,7 @@ export const {
fieldImageValueChanged,
fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged,
+ fieldSpandrelImageToImageModelValueChanged,
fieldLabelChanged,
fieldLoRAModelValueChanged,
fieldModelIdentifierValueChanged,
diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts
index 54e126af3a..2ea8900281 100644
--- a/invokeai/frontend/web/src/features/nodes/types/common.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/common.ts
@@ -66,6 +66,7 @@ const zModelType = z.enum([
'embedding',
'onnx',
'clip_vision',
+ 'spandrel_image_to_image',
]);
const zSubModelType = z.enum([
'unet',
diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts
index 4ede5cd479..05697c384c 100644
--- a/invokeai/frontend/web/src/features/nodes/types/constants.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts
@@ -38,6 +38,7 @@ export const MODEL_TYPES = [
'VAEField',
'CLIPField',
'T2IAdapterModelField',
+ 'SpandrelImageToImageModelField',
];
/**
@@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
+ SpandrelImageToImageModelField: 'teal.500',
StringField: 'yellow.500',
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts
index e2a84e3390..925bd40b9d 100644
--- a/invokeai/frontend/web/src/features/nodes/types/field.ts
+++ b/invokeai/frontend/web/src/features/nodes/types/field.ts
@@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'),
originalType: zStatelessFieldType.optional(),
});
+const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
+ name: z.literal('SpandrelImageToImageModelField'),
+ originalType: zStatelessFieldType.optional(),
+});
const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
@@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
zControlNetModelFieldType,
zIPAdapterModelFieldType,
zT2IAdapterModelFieldType,
+ zSpandrelImageToImageModelFieldType,
zColorFieldType,
zSchedulerFieldType,
]);
@@ -581,6 +586,33 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion
+// #region SpandrelImageToImageModelField
+
+export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
+const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
+ value: zSpandrelImageToImageModelFieldValue,
+});
+const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
+ type: zSpandrelImageToImageModelFieldType,
+ originalType: zFieldType.optional(),
+ default: zSpandrelImageToImageModelFieldValue,
+});
+const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
+ type: zSpandrelImageToImageModelFieldType,
+});
+export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
+export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
+export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
+export const isSpandrelImageToImageModelFieldInputInstance = (
+ val: unknown
+): val is SpandrelImageToImageModelFieldInputInstance =>
+ zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
+export const isSpandrelImageToImageModelFieldInputTemplate = (
+ val: unknown
+): val is SpandrelImageToImageModelFieldInputTemplate =>
+ zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
+// #endregion
+
// #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional();
@@ -667,6 +699,7 @@ export const zStatefulFieldValue = z.union([
zControlNetModelFieldValue,
zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue,
+ zSpandrelImageToImageModelFieldValue,
zColorFieldValue,
zSchedulerFieldValue,
]);
@@ -694,6 +727,7 @@ const zStatefulFieldInputInstance = z.union([
zControlNetModelFieldInputInstance,
zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance,
+ zSpandrelImageToImageModelFieldInputInstance,
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
]);
@@ -722,6 +756,7 @@ const zStatefulFieldInputTemplate = z.union([
zControlNetModelFieldInputTemplate,
zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate,
+ zSpandrelImageToImageModelFieldInputTemplate,
zColorFieldInputTemplate,
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
@@ -751,6 +786,7 @@ const zStatefulFieldOutputTemplate = z.union([
zControlNetModelFieldOutputTemplate,
zIPAdapterModelFieldOutputTemplate,
zT2IAdapterModelFieldOutputTemplate,
+ zSpandrelImageToImageModelFieldOutputTemplate,
zColorFieldOutputTemplate,
zSchedulerFieldOutputTemplate,
]);
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
index 597779fd61..a5a2d89f03 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts
@@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record =
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,
+ SpandrelImageToImageModelField: undefined,
VAEModelField: undefined,
ControlNetModelField: undefined,
};
diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
index 2b77274526..8478415cd1 100644
--- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts
@@ -17,6 +17,7 @@ import type {
SchedulerFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
+ SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldInputTemplate,
@@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapterModelFieldInputTemplate> = ({ schemaObject, baseField, fieldType }) => {
+const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
+  SpandrelImageToImageModelFieldInputTemplate
+> = ({ schemaObject, baseField, fieldType }) => {
+ const template: SpandrelImageToImageModelFieldInputTemplate = {
+ ...baseField,
+ type: fieldType,
+ default: schemaObject.default ?? undefined,
+ };
+
+ return template;
+};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder = ({
schemaObject,
baseField,
@@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
+  SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts
--- a/invokeai/frontend/web/src/services/api/types.ts
+++ b/invokeai/frontend/web/src/services/api/types.ts
+export const isSpandrelImageToImageModelConfig = (
+  config: AnyModelConfig
+): config is SpandrelImageToImageModelConfig => {
+ return config.type === 'spandrel_image_to_image';
+};
+
export const isControlAdapterModelConfig = (
config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {
diff --git a/pyproject.toml b/pyproject.toml
index a11a19071c..9953c1c1a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",
+ "spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2",
"torchmetrics==0.11.4",