Add support for Spandrel Image-to-Image models (e.g. ESRGAN, Real-ESRGAN, Swin-IR, DAT, etc.) (#6556)

## Summary

- Add support for all
[spandrel](https://github.com/chaiNNer-org/spandrel) image-to-image
models - this is a collection of many popular super-resolution models
(e.g. ESRGAN, Real-ESRGAN, SwinIR, DAT, etc.)

Examples of supported models:

- DAT:
https://drive.google.com/drive/folders/1iBdf_-LVZuz_PAbFtuxSKd_11RL1YKxM
- SwinIR: https://github.com/JingyunLiang/SwinIR/releases
- Any ESRGAN / Real-ESRGAN model

## Related Issues

Closes #6394 

## QA Instructions

- [x] Test that unsupported models still fail the probe (i.e. no false
positive spandrel models)
- [x] Test adding a few non-spandrel model types
- [x] Test adding a handful of spandrel model types: ESRGAN,
Real-ESRGAN, SwinIR, DAT
- [x] Verify model size estimation for the model cache
- [x] Test using the spandrel models in a practical image upscaling
workflow

## Merge Plan

- [x] Get approval from @brandonrising and @maryhipp before merging -
this PR has commercial implications.
- [x] Merge #6571 and change the target branch to `main`

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-07-16 15:37:20 -04:00 committed by GitHub
commit 7ad32dcad2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 734 additions and 172 deletions

View File

@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField" ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField" IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField" T2IAdapterModel = "T2IAdapterModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion # endregion
# region Misc Field Types # region Misc Field Types
@ -134,6 +135,7 @@ class FieldDescriptions:
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Image-to-Image model"
lora_weight = "The weight at which the LoRA is applied to each model" lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)" raw_prompt = "Raw prompt text (no parsing)"

View File

@ -0,0 +1,49 @@
import torch
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
image: ImageField = InputField(description="The input image")
image_to_image_model: ModelIdentifierField = InputField(
title="Image-to-Image Model",
description=FieldDescriptions.spandrel_image_to_image_model,
ui_type=UIType.SpandrelImageToImageModel,
)
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name)
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Prepare input image for inference.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run inference.
image_tensor = spandrel_model.run(image_tensor)
# Convert the output tensor to a PIL image.
pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

View File

@ -67,6 +67,7 @@ class ModelType(str, Enum):
IPAdapter = "ip_adapter" IPAdapter = "ip_adapter"
CLIPVision = "clip_vision" CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter" T2IAdapter = "t2i_adapter"
SpandrelImageToImage = "spandrel_image_to_image"
class SubModelType(str, Enum): class SubModelType(str, Enum):
@ -371,6 +372,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}") return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
def get_model_discriminator_value(v: Any) -> str: def get_model_discriminator_value(v: Any) -> str:
""" """
Computes the discriminator value for a model config. Computes the discriminator value for a model config.
@ -407,6 +419,7 @@ AnyModelConfig = Annotated[
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
], ],
Discriminator(get_model_discriminator_value), Discriminator(get_model_discriminator_value),

View File

@ -0,0 +1,45 @@
from pathlib import Path
from typing import Optional
import torch
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@ModelLoaderRegistry.register(
base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
)
class SpandrelImageToImageModelLoader(ModelLoader):
"""Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for Spandrel model.")
model_path = Path(config.path)
model = SpandrelImageToImageModel.load_from_file(model_path)
torch_dtype = self._torch_dtype
if not model.supports_dtype(torch_dtype):
self._logger.warning(
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
"model. Falling back to 'float32'."
)
torch_dtype = torch.float32
model.to(dtype=torch_dtype)
return model

View File

@ -15,6 +15,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw from invokeai.backend.textual_inversion import TextualInversionModelRaw
@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
elif isinstance(model, CLIPTokenizer): elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now. # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0 return 0
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)): elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
return model.calc_size() return model.calc_size()
else: else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch import safetensors.torch
import spandrel
import torch import torch
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType, SchedulerPredictionType,
) )
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
CkptType = Dict[str | int, Any] CkptType = Dict[str | int, Any]
@ -220,25 +222,47 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt) ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]: for key in [str(k) for k in ckpt.keys()]:
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
return ModelType.Main return ModelType.Main
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
return ModelType.VAE return ModelType.VAE
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): elif key.startswith(("lora_te_", "lora_unet_")):
return ModelType.LoRA return ModelType.LoRA
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
return ModelType.LoRA return ModelType.LoRA
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}): elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet return ModelType.ControlNet
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}): elif key.startswith(("image_proj.", "ip_adapter.")):
return ModelType.IPAdapter return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion return ModelType.TextualInversion
else:
# diffusers-ti # diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion return ModelType.TextualInversion
# Check if the model can be loaded as a SpandrelImageToImageModel.
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage
except spandrel.UnsupportedModelError:
pass
except RuntimeError as e:
if "No such file or directory" in str(e):
# This error is expected if the model_path does not exist (which is the case in some unit tests).
pass
else:
raise e
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@classmethod @classmethod
@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError() raise NotImplementedError()
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
######################################################## ########################################################
# classes for probing folders # classes for probing folders
####################################################### #######################################################
@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterFolderProbe(FolderProbeBase): class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json" config_file = self.model_path / "config.json"
@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -1,6 +1,13 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
class RawModel(ABC):
"""Base class for 'Raw' models. """Base class for 'Raw' models.
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw, The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
and is used for type checking of calls to the model patcher. Its main purpose and is used for type checking of calls to the model patcher. Its main purpose
is to avoid a circular import issues when lora.py tries to import BaseModelType is to avoid a circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
@ -10,15 +17,6 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes. that adds additional methods and attributes.
""" """
from abc import ABC, abstractmethod
from typing import Optional
import torch
class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""
@abstractmethod @abstractmethod
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
pass pass

View File

@ -0,0 +1,134 @@
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader
from invokeai.backend.raw_model import RawModel
class SpandrelImageToImageModel(RawModel):
"""A wrapper for a Spandrel Image-to-Image model.
The main reason for having a wrapper class is to integrate with the type handling of RawModel.
"""
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
self._spandrel_model = spandrel_model
@staticmethod
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
"""Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
Args:
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
Returns:
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
"""
image_np = np.array(image)
# (H, W, C) -> (C, H, W)
image_np = np.transpose(image_np, (2, 0, 1))
image_np = image_np / 255
image_tensor = torch.from_numpy(image_np).float()
# (C, H, W) -> (N, C, H, W)
image_tensor = image_tensor.unsqueeze(0)
return image_tensor
@staticmethod
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
"""Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
Args:
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
Returns:
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
"""
# (N, C, H, W) -> (C, H, W)
tensor = tensor.squeeze(0)
# (C, H, W) -> (H, W, C)
tensor = tensor.permute(1, 2, 0)
tensor = tensor.clamp(0, 1)
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
image = Image.fromarray(tensor)
return image
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""Run the image-to-image model.
Args:
image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
"""
return self._spandrel_model(image_tensor)
@classmethod
def load_from_file(cls, file_path: str | Path):
model = ModelLoader().load_from_file(file_path)
if not isinstance(model, ImageModelDescriptor):
raise ValueError(
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
"('ImageModelDescriptor')."
)
return cls(spandrel_model=model)
@classmethod
def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
model = ModelLoader().load_from_state_dict(state_dict)
if not isinstance(model, ImageModelDescriptor):
raise ValueError(
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
"('ImageModelDescriptor')."
)
return cls(spandrel_model=model)
def supports_dtype(self, dtype: torch.dtype) -> bool:
"""Check if the model supports the given dtype."""
if dtype == torch.float16:
return self._spandrel_model.supports_half
elif dtype == torch.bfloat16:
return self._spandrel_model.supports_bfloat16
elif dtype == torch.float32:
# All models support float32.
return True
else:
raise ValueError(f"Unexpected dtype '{dtype}'.")
def get_model_type_name(self) -> str:
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
consistent over time.
"""
return str(type(self._spandrel_model.model))
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
"""Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
Note: The non_blocking parameter is currently ignored."""
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
# model directly if we want to apply this optimization.
self._spandrel_model.to(device=device, dtype=dtype)
@property
def device(self) -> torch.device:
"""The device of the underlying model."""
return self._spandrel_model.device
@property
def dtype(self) -> torch.dtype:
"""The dtype of the underlying model."""
return self._spandrel_model.dtype
def calc_size(self) -> int:
"""Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._spandrel_model.model)

View File

@ -11,6 +11,7 @@ import {
useLoRAModels, useLoRAModels,
useMainModels, useMainModels,
useRefinerModels, useRefinerModels,
useSpandrelImageToImageModels,
useT2IAdapterModels, useT2IAdapterModels,
useVAEModels, useVAEModels,
} from 'services/api/hooks/modelsByType'; } from 'services/api/hooks/modelsByType';
@ -71,6 +72,13 @@ const ModelList = () => {
[vaeModels, searchTerm, filteredModelType] [vaeModels, searchTerm, filteredModelType]
); );
const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] =
useSpandrelImageToImageModels();
const filteredSpandrelImageToImageModels = useMemo(
() => modelsFilter(spandrelImageToImageModels, searchTerm, filteredModelType),
[spandrelImageToImageModels, searchTerm, filteredModelType]
);
const totalFilteredModels = useMemo(() => { const totalFilteredModels = useMemo(() => {
return ( return (
filteredMainModels.length + filteredMainModels.length +
@ -80,7 +88,8 @@ const ModelList = () => {
filteredControlNetModels.length + filteredControlNetModels.length +
filteredT2IAdapterModels.length + filteredT2IAdapterModels.length +
filteredIPAdapterModels.length + filteredIPAdapterModels.length +
filteredVAEModels.length filteredVAEModels.length +
filteredSpandrelImageToImageModels.length
); );
}, [ }, [
filteredControlNetModels.length, filteredControlNetModels.length,
@ -91,6 +100,7 @@ const ModelList = () => {
filteredRefinerModels.length, filteredRefinerModels.length,
filteredT2IAdapterModels.length, filteredT2IAdapterModels.length,
filteredVAEModels.length, filteredVAEModels.length,
filteredSpandrelImageToImageModels.length,
]); ]);
return ( return (
@ -143,6 +153,17 @@ const ModelList = () => {
{!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && (
<ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" /> <ModelListWrapper title={t('common.t2iAdapter')} modelList={filteredT2IAdapterModels} key="t2i-adapters" />
)} )}
{/* Spandrel Image to Image List */}
{isLoadingSpandrelImageToImageModels && (
<FetchingModelsLoader loadingMessage="Loading Image-to-Image Models..." />
)}
{!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && (
<ModelListWrapper
title="Image-to-Image"
modelList={filteredSpandrelImageToImageModels}
key="spandrel-image-to-image"
/>
)}
{totalFilteredModels === 0 && ( {totalFilteredModels === 0 && (
<Flex w="full" h="full" alignItems="center" justifyContent="center"> <Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>{t('modelManager.noMatchingModels')}</Text> <Text>{t('modelManager.noMatchingModels')}</Text>

View File

@ -21,6 +21,7 @@ export const ModelTypeFilter = () => {
t2i_adapter: t('common.t2iAdapter'), t2i_adapter: t('common.t2iAdapter'),
ip_adapter: t('common.ipAdapter'), ip_adapter: t('common.ipAdapter'),
clip_vision: 'Clip Vision', clip_vision: 'Clip Vision',
spandrel_image_to_image: 'Image-to-Image',
}), }),
[t] [t]
); );

View File

@ -32,6 +32,8 @@ import {
isSDXLMainModelFieldInputTemplate, isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance, isSDXLRefinerModelFieldInputInstance,
isSDXLRefinerModelFieldInputTemplate, isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldInputInstance, isStringFieldInputInstance,
isStringFieldInputTemplate, isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance, isT2IAdapterModelFieldInputInstance,
@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent'; import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent'; import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
@ -125,6 +128,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) { if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }
if (
isSpandrelImageToImageModelFieldInputInstance(fieldInstance) &&
isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate)
) {
return (
<SpandrelImageToImageModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
/>
);
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />; return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
} }

View File

@ -0,0 +1,55 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice';
import type {
SpandrelImageToImageModelFieldInputInstance,
SpandrelImageToImageModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType';
import type { SpandrelImageToImageModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
const SpandrelImageToImageModelFieldInputComponent = (
props: FieldComponentProps<SpandrelImageToImageModelFieldInputInstance, SpandrelImageToImageModelFieldInputTemplate>
) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels();
const _onChange = useCallback(
(value: SpandrelImageToImageModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldSpandrelImageToImageModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
selectedModel: field.value,
isLoading,
});
return (
<Tooltip label={value?.description}>
<FormControl className="nowheel nodrag" isInvalid={!value}>
<Combobox value={value} placeholder="Pick one" options={options} onChange={onChange} />
</FormControl>
</Tooltip>
);
};
export default memo(SpandrelImageToImageModelFieldInputComponent);

View File

@ -19,6 +19,7 @@ import type {
ModelIdentifierFieldValue, ModelIdentifierFieldValue,
SchedulerFieldValue, SchedulerFieldValue,
SDXLRefinerModelFieldValue, SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue, StatefulFieldValue,
StringFieldValue, StringFieldValue,
T2IAdapterModelFieldValue, T2IAdapterModelFieldValue,
@ -39,6 +40,7 @@ import {
zModelIdentifierFieldValue, zModelIdentifierFieldValue,
zSchedulerFieldValue, zSchedulerFieldValue,
zSDXLRefinerModelFieldValue, zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue, zStatefulFieldValue,
zStringFieldValue, zStringFieldValue,
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
@ -333,6 +335,12 @@ export const nodesSlice = createSlice({
fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => { fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction<T2IAdapterModelFieldValue>) => {
fieldValueReducer(state, action, zT2IAdapterModelFieldValue); fieldValueReducer(state, action, zT2IAdapterModelFieldValue);
}, },
fieldSpandrelImageToImageModelValueChanged: (
state,
action: FieldValueAction<SpandrelImageToImageModelFieldValue>
) => {
fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue);
},
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => { fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
fieldValueReducer(state, action, zEnumFieldValue); fieldValueReducer(state, action, zEnumFieldValue);
}, },
@ -384,6 +392,7 @@ export const {
fieldImageValueChanged, fieldImageValueChanged,
fieldIPAdapterModelValueChanged, fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged, fieldT2IAdapterModelValueChanged,
fieldSpandrelImageToImageModelValueChanged,
fieldLabelChanged, fieldLabelChanged,
fieldLoRAModelValueChanged, fieldLoRAModelValueChanged,
fieldModelIdentifierValueChanged, fieldModelIdentifierValueChanged,

View File

@ -66,6 +66,7 @@ const zModelType = z.enum([
'embedding', 'embedding',
'onnx', 'onnx',
'clip_vision', 'clip_vision',
'spandrel_image_to_image',
]); ]);
const zSubModelType = z.enum([ const zSubModelType = z.enum([
'unet', 'unet',

View File

@ -38,6 +38,7 @@ export const MODEL_TYPES = [
'VAEField', 'VAEField',
'CLIPField', 'CLIPField',
'T2IAdapterModelField', 'T2IAdapterModelField',
'SpandrelImageToImageModelField',
]; ];
/** /**
@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500', MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500', SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500', SDXLRefinerModelField: 'teal.500',
SpandrelImageToImageModelField: 'teal.500',
StringField: 'yellow.500', StringField: 'yellow.500',
T2IAdapterField: 'teal.500', T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500', T2IAdapterModelField: 'teal.500',

View File

@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({
name: z.literal('T2IAdapterModelField'), name: z.literal('T2IAdapterModelField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
}); });
const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({
name: z.literal('SpandrelImageToImageModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSchedulerFieldType = zFieldTypeBase.extend({ const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'), name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(), originalType: zStatelessFieldType.optional(),
@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([
zControlNetModelFieldType, zControlNetModelFieldType,
zIPAdapterModelFieldType, zIPAdapterModelFieldType,
zT2IAdapterModelFieldType, zT2IAdapterModelFieldType,
zSpandrelImageToImageModelFieldType,
zColorFieldType, zColorFieldType,
zSchedulerFieldType, zSchedulerFieldType,
]); ]);
@ -581,6 +586,33 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda
zT2IAdapterModelFieldInputTemplate.safeParse(val).success; zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
// #endregion // #endregion
// #region SpandrelModelToModelField
export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional();
const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSpandrelImageToImageModelFieldType,
originalType: zFieldType.optional(),
default: zSpandrelImageToImageModelFieldValue,
});
const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSpandrelImageToImageModelFieldType,
});
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = (
val: unknown
): val is SpandrelImageToImageModelFieldInputInstance =>
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputTemplate = (
val: unknown
): val is SpandrelImageToImageModelFieldInputTemplate =>
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SchedulerField // #region SchedulerField
export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldValue = zSchedulerField.optional();
@ -667,6 +699,7 @@ export const zStatefulFieldValue = z.union([
zControlNetModelFieldValue, zControlNetModelFieldValue,
zIPAdapterModelFieldValue, zIPAdapterModelFieldValue,
zT2IAdapterModelFieldValue, zT2IAdapterModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zColorFieldValue, zColorFieldValue,
zSchedulerFieldValue, zSchedulerFieldValue,
]); ]);
@ -694,6 +727,7 @@ const zStatefulFieldInputInstance = z.union([
zControlNetModelFieldInputInstance, zControlNetModelFieldInputInstance,
zIPAdapterModelFieldInputInstance, zIPAdapterModelFieldInputInstance,
zT2IAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance,
zSpandrelImageToImageModelFieldInputInstance,
zColorFieldInputInstance, zColorFieldInputInstance,
zSchedulerFieldInputInstance, zSchedulerFieldInputInstance,
]); ]);
@ -722,6 +756,7 @@ const zStatefulFieldInputTemplate = z.union([
zControlNetModelFieldInputTemplate, zControlNetModelFieldInputTemplate,
zIPAdapterModelFieldInputTemplate, zIPAdapterModelFieldInputTemplate,
zT2IAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate,
zSpandrelImageToImageModelFieldInputTemplate,
zColorFieldInputTemplate, zColorFieldInputTemplate,
zSchedulerFieldInputTemplate, zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate, zStatelessFieldInputTemplate,
@ -751,6 +786,7 @@ const zStatefulFieldOutputTemplate = z.union([
zControlNetModelFieldOutputTemplate, zControlNetModelFieldOutputTemplate,
zIPAdapterModelFieldOutputTemplate, zIPAdapterModelFieldOutputTemplate,
zT2IAdapterModelFieldOutputTemplate, zT2IAdapterModelFieldOutputTemplate,
zSpandrelImageToImageModelFieldOutputTemplate,
zColorFieldOutputTemplate, zColorFieldOutputTemplate,
zSchedulerFieldOutputTemplate, zSchedulerFieldOutputTemplate,
]); ]);

View File

@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
SDXLRefinerModelField: undefined, SDXLRefinerModelField: undefined,
StringField: '', StringField: '',
T2IAdapterModelField: undefined, T2IAdapterModelField: undefined,
SpandrelImageToImageModelField: undefined,
VAEModelField: undefined, VAEModelField: undefined,
ControlNetModelField: undefined, ControlNetModelField: undefined,
}; };

View File

@ -17,6 +17,7 @@ import type {
SchedulerFieldInputTemplate, SchedulerFieldInputTemplate,
SDXLMainModelFieldInputTemplate, SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate,
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType, StatefulFieldType,
StatelessFieldInputTemplate, StatelessFieldInputTemplate,
StringFieldInputTemplate, StringFieldInputTemplate,
@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder<T2IAdapt
return template; return template;
}; };
const buildSpandrelImageToImageModelFieldInputTemplate: FieldInputTemplateBuilder<
SpandrelImageToImageModelFieldInputTemplate
> = ({ schemaObject, baseField, fieldType }) => {
const template: SpandrelImageToImageModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder<BoardFieldInputTemplate> = ({
schemaObject, schemaObject,
baseField, baseField,
@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate, SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate, StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate, T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate, VAEModelField: buildVAEModelFieldInputTemplate,
} as const; } as const;

View File

@ -35,6 +35,7 @@ const MODEL_FIELD_TYPES = [
'ControlNetModelField', 'ControlNetModelField',
'IPAdapterModelField', 'IPAdapterModelField',
'T2IAdapterModelField', 'T2IAdapterModelField',
'SpandrelImageToImageModelField',
]; ];
/** /**

View File

@ -11,6 +11,7 @@ import {
isNonSDXLMainModelConfig, isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig, isRefinerMainModelModelConfig,
isSDXLMainModelModelConfig, isSDXLMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isT2IAdapterModelConfig, isT2IAdapterModelConfig,
isTIModelConfig, isTIModelConfig,
isVAEModelConfig, isVAEModelConfig,
@ -39,6 +40,7 @@ export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig); export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig); export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig); export const useT2IAdapterModels = buildModelsHook(isT2IAdapterModelConfig);
export const useSpandrelImageToImageModels = buildModelsHook(isSpandrelImageToImageModelConfig);
export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig); export const useIPAdapterModels = buildModelsHook(isIPAdapterModelConfig);
export const useEmbeddingModels = buildModelsHook(isTIModelConfig); export const useEmbeddingModels = buildModelsHook(isTIModelConfig);
export const useVAEModels = buildModelsHook(isVAEModelConfig); export const useVAEModels = buildModelsHook(isVAEModelConfig);

File diff suppressed because one or more lines are too long

View File

@ -51,6 +51,7 @@ export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig']; export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig']; export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig']; type DiffusersModelConfig = S['MainDiffusersConfig'];
type CheckpointModelConfig = S['MainCheckpointConfig']; type CheckpointModelConfig = S['MainCheckpointConfig'];
@ -62,6 +63,7 @@ export type AnyModelConfig =
| ControlNetModelConfig | ControlNetModelConfig
| IPAdapterModelConfig | IPAdapterModelConfig
| T2IAdapterModelConfig | T2IAdapterModelConfig
| SpandrelImageToImageModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig | MainModelConfig
| CLIPVisionDiffusersConfig; | CLIPVisionDiffusersConfig;
@ -86,6 +88,12 @@ export const isT2IAdapterModelConfig = (config: AnyModelConfig): config is T2IAd
return config.type === 't2i_adapter'; return config.type === 't2i_adapter';
}; };
export const isSpandrelImageToImageModelConfig = (
config: AnyModelConfig
): config is SpandrelImageToImageModelConfig => {
return config.type === 'spandrel_image_to_image';
};
export const isControlAdapterModelConfig = ( export const isControlAdapterModelConfig = (
config: AnyModelConfig config: AnyModelConfig
): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => { ): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => {

View File

@ -46,6 +46,7 @@ dependencies = [
"opencv-python==4.9.0.80", "opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3", "pytorch-lightning==2.1.3",
"safetensors==0.4.3", "safetensors==0.4.3",
"spandrel==0.3.4",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2", "torch==2.2.2",
"torchmetrics==0.11.4", "torchmetrics==0.11.4",