diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 8374a58959..f9a483f84c 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -48,6 +48,7 @@ class UIType(str, Enum, metaclass=MetaEnum): ControlNetModel = "ControlNetModelField" IPAdapterModel = "IPAdapterModelField" T2IAdapterModel = "T2IAdapterModelField" + SpandrelImageToImageModel = "SpandrelImageToImageModelField" # endregion # region Misc Field Types @@ -134,6 +135,7 @@ class FieldDescriptions: sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + spandrel_image_to_image_model = "Image-to-Image model" lora_weight = "The weight at which the LoRA is applied to each model" compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" raw_prompt = "Raw prompt text (no parsing)" diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py new file mode 100644 index 0000000000..76cf31480c --- /dev/null +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -0,0 +1,49 @@ +import torch + +from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + InputField, + UIType, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import ModelIdentifierField +from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel + + +@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0") +class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): + """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel).""" + + image: ImageField = InputField(description="The input image") + image_to_image_model: ModelIdentifierField = InputField( + title="Image-to-Image Model", + description=FieldDescriptions.spandrel_image_to_image_model, + ui_type=UIType.SpandrelImageToImageModel, + ) + + @torch.inference_mode() + def invoke(self, context: InvocationContext) -> ImageOutput: + image = context.images.get_pil(self.image.image_name) + + # Load the model. + spandrel_model_info = context.models.load(self.image_to_image_model) + + with spandrel_model_info as spandrel_model: + assert isinstance(spandrel_model, SpandrelImageToImageModel) + + # Prepare input image for inference. + image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) + image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) + + # Run inference. + image_tensor = spandrel_model.run(image_tensor) + + # Convert the output tensor to a PIL image. 
+ pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor) + image_dto = context.images.save(image=pil_image) + return ImageOutput.build(image_dto) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index dbcd259368..f6cc5929c8 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -67,6 +67,7 @@ class ModelType(str, Enum): IPAdapter = "ip_adapter" CLIPVision = "clip_vision" T2IAdapter = "t2i_adapter" + SpandrelImageToImage = "spandrel_image_to_image" class SubModelType(str, Enum): @@ -371,6 +372,17 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase): return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}") +class SpandrelImageToImageConfig(ModelConfigBase): + """Model config for Spandrel Image to Image models.""" + + type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage + format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}") + + def get_model_discriminator_value(v: Any) -> str: """ Computes the discriminator value for a model config. @@ -407,6 +419,7 @@ AnyModelConfig = Annotated[ Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], + Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()], Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], ], Discriminator(get_model_discriminator_value), diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py new file mode 100644 index 0000000000..7a57c5cf59 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py @@ -0,0 +1,45 @@ +from pathlib import Path +from typing import Optional + +import torch + +from invokeai.backend.model_manager.config import ( + AnyModel, + AnyModelConfig, + BaseModelType, + ModelFormat, + ModelType, + SubModelType, +) +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel + + +@ModelLoaderRegistry.register( + base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint +) +class SpandrelImageToImageModelLoader(ModelLoader): + """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor).""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if submodel_type is not None: + raise ValueError("Unexpected submodel requested for Spandrel model.") + + model_path = Path(config.path) + model = SpandrelImageToImageModel.load_from_file(model_path) + + torch_dtype = self._torch_dtype + if not model.supports_dtype(torch_dtype): + self._logger.warning( + f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} " + "model. Falling back to 'float32'." 
+ ) + torch_dtype = torch.float32 + model.to(dtype=torch_dtype) + + return model diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py index 64fbd29a1f..f070a42965 100644 --- a/invokeai/backend/model_manager/load/model_util.py +++ b/invokeai/backend/model_manager/load/model_util.py @@ -15,6 +15,7 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_manager.config import AnyModel from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.textual_inversion import TextualInversionModelRaw @@ -33,7 +34,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int: elif isinstance(model, CLIPTokenizer): # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now. return 0 - elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)): + elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)): return model.calc_size() else: # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index f6fb2d24bc..1929b3f4fd 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict, Literal, Optional, Union import safetensors.torch +import spandrel import torch from picklescan.scanner import scan_file_path @@ -25,6 +26,7 @@ from invokeai.backend.model_manager.config import ( SchedulerPredictionType, ) from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.util.silence_warnings import SilenceWarnings CkptType = Dict[str | int, Any] @@ -220,24 +222,46 @@ class ModelProbe(object): ckpt = ckpt.get("state_dict", ckpt) for key in [str(k) for k in ckpt.keys()]: - if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}): + if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")): return ModelType.Main - elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}): + elif key.startswith(("encoder.conv_in", "decoder.conv_in")): return ModelType.VAE - elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}): + elif key.startswith(("lora_te_", "lora_unet_")): return ModelType.LoRA - elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}): + elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")): return ModelType.LoRA - elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}): + elif key.startswith(("controlnet", "control_model", "input_blocks")): return ModelType.ControlNet - elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}): + elif key.startswith(("image_proj.", "ip_adapter.")): return ModelType.IPAdapter elif key in {"emb_params", "string_to_param"}: return ModelType.TextualInversion - else: - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion + + # diffusers-ti 
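+        # A short state dict of plain tensors, with none of the recognized key prefixes above, is assumed to be a diffusers-format Textual Inversion embedding.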
+ if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): + return ModelType.TextualInversion + + # Check if the model can be loaded as a SpandrelImageToImageModel. + # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk). + try: + # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were + # explored to avoid this: + # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta + # device. Unfortunately, some Spandrel models perform operations during initialization that are not + # supported on meta tensors. + # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. + # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to + # maintain it, and the risk of false positive detections is higher. + SpandrelImageToImageModel.load_from_file(model_path) + return ModelType.SpandrelImageToImage + except spandrel.UnsupportedModelError: + pass + except RuntimeError as e: + if "No such file or directory" in str(e): + # This error is expected if the model_path does not exist (which is the case in some unit tests). + pass + else: + raise e raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") @@ -569,6 +593,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase): raise NotImplementedError() +class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase): + def get_base_type(self) -> BaseModelType: + return BaseModelType.Any + + ######################################################## # classes for probing folders ####################################################### @@ -776,6 +805,11 @@ class CLIPVisionFolderProbe(FolderProbeBase): return BaseModelType.Any +class SpandrelImageToImageFolderProbe(FolderProbeBase): + def get_base_type(self) -> BaseModelType: + raise NotImplementedError() + + class T2IAdapterFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" @@ -805,6 +839,7 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) +ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) @@ -814,5 +849,6 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) +ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/raw_model.py b/invokeai/backend/raw_model.py index 931804c985..23502b20cb 100644 --- a/invokeai/backend/raw_model.py +++ b/invokeai/backend/raw_model.py @@ -1,15 +1,3 @@ -"""Base class for 'Raw' models. 
- -The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw, -and is used for type checking of calls to the model patcher. Its main purpose -is to avoid a circular import issues when lora.py tries to import BaseModelType -from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw -from lora.py. - -The term 'raw' was introduced to describe a wrapper around a torch.nn.Module -that adds additional methods and attributes. -""" - from abc import ABC, abstractmethod from typing import Optional @@ -17,7 +5,17 @@ import torch class RawModel(ABC): - """Abstract base class for 'Raw' model wrappers.""" + """Base class for 'Raw' models. + + The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc. + and is used for type checking of calls to the model patcher. Its main purpose + is to avoid a circular import issues when lora.py tries to import BaseModelType + from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw + from lora.py. + + The term 'raw' was introduced to describe a wrapper around a torch.nn.Module + that adds additional methods and attributes. + """ @abstractmethod def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None: diff --git a/invokeai/backend/spandrel_image_to_image_model.py b/invokeai/backend/spandrel_image_to_image_model.py new file mode 100644 index 0000000000..adb78d0d71 --- /dev/null +++ b/invokeai/backend/spandrel_image_to_image_model.py @@ -0,0 +1,134 @@ +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import torch +from PIL import Image +from spandrel import ImageModelDescriptor, ModelLoader + +from invokeai.backend.raw_model import RawModel + + +class SpandrelImageToImageModel(RawModel): + """A wrapper for a Spandrel Image-to-Image model. + + The main reason for having a wrapper class is to integrate with the type handling of RawModel. + """ + + def __init__(self, spandrel_model: ImageModelDescriptor[Any]): + self._spandrel_model = spandrel_model + + @staticmethod + def pil_to_tensor(image: Image.Image) -> torch.Tensor: + """Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run(). + + Args: + image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255]. + + Returns: + torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1]. + """ + image_np = np.array(image) + # (H, W, C) -> (C, H, W) + image_np = np.transpose(image_np, (2, 0, 1)) + image_np = image_np / 255 + image_tensor = torch.from_numpy(image_np).float() + # (C, H, W) -> (N, C, H, W) + image_tensor = image_tensor.unsqueeze(0) + return image_tensor + + @staticmethod + def tensor_to_pil(tensor: torch.Tensor) -> Image.Image: + """Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image. + + Args: + tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1]. + + Returns: + Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255]. + """ + # (N, C, H, W) -> (C, H, W) + tensor = tensor.squeeze(0) + # (C, H, W) -> (H, W, C) + tensor = tensor.permute(1, 2, 0) + tensor = tensor.clamp(0, 1) + tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8) + image = Image.fromarray(tensor) + return image + + def run(self, image_tensor: torch.Tensor) -> torch.Tensor: + """Run the image-to-image model. 
+ + Args: + image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1]. + """ + return self._spandrel_model(image_tensor) + + @classmethod + def load_from_file(cls, file_path: str | Path): + model = ModelLoader().load_from_file(file_path) + if not isinstance(model, ImageModelDescriptor): + raise ValueError( + f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported " + "('ImageModelDescriptor')." + ) + + return cls(spandrel_model=model) + + @classmethod + def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]): + model = ModelLoader().load_from_state_dict(state_dict) + if not isinstance(model, ImageModelDescriptor): + raise ValueError( + f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported " + "('ImageModelDescriptor')." + ) + + return cls(spandrel_model=model) + + def supports_dtype(self, dtype: torch.dtype) -> bool: + """Check if the model supports the given dtype.""" + if dtype == torch.float16: + return self._spandrel_model.supports_half + elif dtype == torch.bfloat16: + return self._spandrel_model.supports_bfloat16 + elif dtype == torch.float32: + # All models support float32. + return True + else: + raise ValueError(f"Unexpected dtype '{dtype}'.") + + def get_model_type_name(self) -> str: + """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining + consistent over time. + """ + return str(type(self._spandrel_model.model)) + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + non_blocking: bool = False, + ) -> None: + """Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported. + Note: The non_blocking parameter is currently ignored.""" + # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the + # model directly if we want to apply this optimization. + self._spandrel_model.to(device=device, dtype=dtype) + + @property + def device(self) -> torch.device: + """The device of the underlying model.""" + return self._spandrel_model.device + + @property + def dtype(self) -> torch.dtype: + """The dtype of the underlying model.""" + return self._spandrel_model.dtype + + def calc_size(self) -> int: + """Get size of the model in memory in bytes.""" + # HACK(ryand): Fix this issue with circular imports. 
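+        # (model_util imports SpandrelImageToImageModel at module level for calc_model_size_by_data, so a top-level import of calc_module_size here would be circular.)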
+ from invokeai.backend.model_manager.load.model_util import calc_module_size + + return calc_module_size(self._spandrel_model.model) diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx index 67e65dbfb6..b82917221e 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelList.tsx @@ -11,6 +11,7 @@ import { useLoRAModels, useMainModels, useRefinerModels, + useSpandrelImageToImageModels, useT2IAdapterModels, useVAEModels, } from 'services/api/hooks/modelsByType'; @@ -71,6 +72,13 @@ const ModelList = () => { [vaeModels, searchTerm, filteredModelType] ); + const [spandrelImageToImageModels, { isLoading: isLoadingSpandrelImageToImageModels }] = + useSpandrelImageToImageModels(); + const filteredSpandrelImageToImageModels = useMemo( + () => modelsFilter(spandrelImageToImageModels, searchTerm, filteredModelType), + [spandrelImageToImageModels, searchTerm, filteredModelType] + ); + const totalFilteredModels = useMemo(() => { return ( filteredMainModels.length + @@ -80,7 +88,8 @@ const ModelList = () => { filteredControlNetModels.length + filteredT2IAdapterModels.length + filteredIPAdapterModels.length + - filteredVAEModels.length + filteredVAEModels.length + + filteredSpandrelImageToImageModels.length ); }, [ filteredControlNetModels.length, @@ -91,6 +100,7 @@ const ModelList = () => { filteredRefinerModels.length, filteredT2IAdapterModels.length, filteredVAEModels.length, + filteredSpandrelImageToImageModels.length, ]); return ( @@ -143,6 +153,17 @@ const ModelList = () => { {!isLoadingT2IAdapterModels && filteredT2IAdapterModels.length > 0 && ( )} + {/* Spandrel Image to Image List */} + {isLoadingSpandrelImageToImageModels && ( + + )} + {!isLoadingSpandrelImageToImageModels && filteredSpandrelImageToImageModels.length > 0 && ( + + )} {totalFilteredModels === 0 && ( {t('modelManager.noMatchingModels')} diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx index 76802b36e7..1a2444870b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelTypeFilter.tsx @@ -21,6 +21,7 @@ export const ModelTypeFilter = () => { t2i_adapter: t('common.t2iAdapter'), ip_adapter: t('common.ipAdapter'), clip_vision: 'Clip Vision', + spandrel_image_to_image: 'Image-to-Image', }), [t] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 99937ceec4..d863def973 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -32,6 +32,8 @@ import { isSDXLMainModelFieldInputTemplate, isSDXLRefinerModelFieldInputInstance, isSDXLRefinerModelFieldInputTemplate, + isSpandrelImageToImageModelFieldInputInstance, + isSpandrelImageToImageModelFieldInputTemplate, isStringFieldInputInstance, isStringFieldInputTemplate, 
isT2IAdapterModelFieldInputInstance, @@ -54,6 +56,7 @@ import NumberFieldInputComponent from './inputs/NumberFieldInputComponent'; import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent'; import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent'; import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent'; +import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent'; import StringFieldInputComponent from './inputs/StringFieldInputComponent'; import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent'; import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent'; @@ -125,6 +128,20 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) { return ; } + + if ( + isSpandrelImageToImageModelFieldInputInstance(fieldInstance) && + isSpandrelImageToImageModelFieldInputTemplate(fieldTemplate) + ) { + return ( + + ); + } + if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx new file mode 100644 index 0000000000..ccd4eaa797 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SpandrelImageToImageModelFieldInputComponent.tsx @@ -0,0 +1,55 @@ +import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldSpandrelImageToImageModelValueChanged } from 'features/nodes/store/nodesSlice'; +import type { + SpandrelImageToImageModelFieldInputInstance, + SpandrelImageToImageModelFieldInputTemplate, +} from 'features/nodes/types/field'; +import { memo, useCallback } from 'react'; +import { useSpandrelImageToImageModels } from 'services/api/hooks/modelsByType'; +import type { SpandrelImageToImageModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +const SpandrelImageToImageModelFieldInputComponent = ( + props: FieldComponentProps +) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + + const [modelConfigs, { isLoading }] = useSpandrelImageToImageModels(); + + const _onChange = useCallback( + (value: SpandrelImageToImageModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldSpandrelImageToImageModelValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const { options, value, onChange } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + selectedModel: field.value, + isLoading, + }); + + return ( + + + + + + ); +}; + +export default memo(SpandrelImageToImageModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5ebc5de147..f9214c1572 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ 
-19,6 +19,7 @@ import type { ModelIdentifierFieldValue, SchedulerFieldValue, SDXLRefinerModelFieldValue, + SpandrelImageToImageModelFieldValue, StatefulFieldValue, StringFieldValue, T2IAdapterModelFieldValue, @@ -39,6 +40,7 @@ import { zModelIdentifierFieldValue, zSchedulerFieldValue, zSDXLRefinerModelFieldValue, + zSpandrelImageToImageModelFieldValue, zStatefulFieldValue, zStringFieldValue, zT2IAdapterModelFieldValue, @@ -333,6 +335,12 @@ export const nodesSlice = createSlice({ fieldT2IAdapterModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zT2IAdapterModelFieldValue); }, + fieldSpandrelImageToImageModelValueChanged: ( + state, + action: FieldValueAction + ) => { + fieldValueReducer(state, action, zSpandrelImageToImageModelFieldValue); + }, fieldEnumModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zEnumFieldValue); }, @@ -384,6 +392,7 @@ export const { fieldImageValueChanged, fieldIPAdapterModelValueChanged, fieldT2IAdapterModelValueChanged, + fieldSpandrelImageToImageModelValueChanged, fieldLabelChanged, fieldLoRAModelValueChanged, fieldModelIdentifierValueChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 54e126af3a..2ea8900281 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -66,6 +66,7 @@ const zModelType = z.enum([ 'embedding', 'onnx', 'clip_vision', + 'spandrel_image_to_image', ]); const zSubModelType = z.enum([ 'unet', diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index 4ede5cd479..05697c384c 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -38,6 +38,7 @@ export const MODEL_TYPES = [ 'VAEField', 'CLIPField', 'T2IAdapterModelField', + 'SpandrelImageToImageModelField', ]; /** @@ -62,6 +63,7 @@ export const FIELD_COLORS: { [key: string]: string } = { MainModelField: 'teal.500', SDXLMainModelField: 'teal.500', SDXLRefinerModelField: 'teal.500', + SpandrelImageToImageModelField: 'teal.500', StringField: 'yellow.500', T2IAdapterField: 'teal.500', T2IAdapterModelField: 'teal.500', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index e2a84e3390..925bd40b9d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -139,6 +139,10 @@ const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ name: z.literal('T2IAdapterModelField'), originalType: zStatelessFieldType.optional(), }); +const zSpandrelImageToImageModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SpandrelImageToImageModelField'), + originalType: zStatelessFieldType.optional(), +}); const zSchedulerFieldType = zFieldTypeBase.extend({ name: z.literal('SchedulerField'), originalType: zStatelessFieldType.optional(), @@ -160,6 +164,7 @@ const zStatefulFieldType = z.union([ zControlNetModelFieldType, zIPAdapterModelFieldType, zT2IAdapterModelFieldType, + zSpandrelImageToImageModelFieldType, zColorFieldType, zSchedulerFieldType, ]); @@ -581,6 +586,33 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda zT2IAdapterModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region 
SpandrelModelToModelField + +export const zSpandrelImageToImageModelFieldValue = zModelIdentifierField.optional(); +const zSpandrelImageToImageModelFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zSpandrelImageToImageModelFieldValue, +}); +const zSpandrelImageToImageModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSpandrelImageToImageModelFieldType, + originalType: zFieldType.optional(), + default: zSpandrelImageToImageModelFieldValue, +}); +const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSpandrelImageToImageModelFieldType, +}); +export type SpandrelImageToImageModelFieldValue = z.infer; +export type SpandrelImageToImageModelFieldInputInstance = z.infer; +export type SpandrelImageToImageModelFieldInputTemplate = z.infer; +export const isSpandrelImageToImageModelFieldInputInstance = ( + val: unknown +): val is SpandrelImageToImageModelFieldInputInstance => + zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success; +export const isSpandrelImageToImageModelFieldInputTemplate = ( + val: unknown +): val is SpandrelImageToImageModelFieldInputTemplate => + zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success; +// #endregion + // #region SchedulerField export const zSchedulerFieldValue = zSchedulerField.optional(); @@ -667,6 +699,7 @@ export const zStatefulFieldValue = z.union([ zControlNetModelFieldValue, zIPAdapterModelFieldValue, zT2IAdapterModelFieldValue, + zSpandrelImageToImageModelFieldValue, zColorFieldValue, zSchedulerFieldValue, ]); @@ -694,6 +727,7 @@ const zStatefulFieldInputInstance = z.union([ zControlNetModelFieldInputInstance, zIPAdapterModelFieldInputInstance, zT2IAdapterModelFieldInputInstance, + zSpandrelImageToImageModelFieldInputInstance, zColorFieldInputInstance, zSchedulerFieldInputInstance, ]); @@ -722,6 +756,7 @@ const zStatefulFieldInputTemplate = z.union([ zControlNetModelFieldInputTemplate, zIPAdapterModelFieldInputTemplate, zT2IAdapterModelFieldInputTemplate, + zSpandrelImageToImageModelFieldInputTemplate, zColorFieldInputTemplate, zSchedulerFieldInputTemplate, zStatelessFieldInputTemplate, @@ -751,6 +786,7 @@ const zStatefulFieldOutputTemplate = z.union([ zControlNetModelFieldOutputTemplate, zIPAdapterModelFieldOutputTemplate, zT2IAdapterModelFieldOutputTemplate, + zSpandrelImageToImageModelFieldOutputTemplate, zColorFieldOutputTemplate, zSchedulerFieldOutputTemplate, ]); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 597779fd61..a5a2d89f03 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -18,6 +18,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = SDXLRefinerModelField: undefined, StringField: '', T2IAdapterModelField: undefined, + SpandrelImageToImageModelField: undefined, VAEModelField: undefined, ControlNetModelField: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 2b77274526..8478415cd1 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -17,6 +17,7 @@ import type { SchedulerFieldInputTemplate, SDXLMainModelFieldInputTemplate, 
SDXLRefinerModelFieldInputTemplate, + SpandrelImageToImageModelFieldInputTemplate, StatefulFieldType, StatelessFieldInputTemplate, StringFieldInputTemplate, @@ -263,6 +264,17 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, fieldType }) => { + const template: SpandrelImageToImageModelFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; const buildBoardFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -377,6 +389,7 @@ export const TEMPLATE_BUILDER_MAP: Record { + return config.type === 'spandrel_image_to_image'; +}; + export const isControlAdapterModelConfig = ( config: AnyModelConfig ): config is ControlNetModelConfig | T2IAdapterModelConfig | IPAdapterModelConfig => { diff --git a/pyproject.toml b/pyproject.toml index a11a19071c..9953c1c1a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "opencv-python==4.9.0.80", "pytorch-lightning==2.1.3", "safetensors==0.4.3", + "spandrel==0.3.4", "timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26 "torch==2.2.2", "torchmetrics==0.11.4",
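
Taken together, the new invocation, the SpandrelImageToImageModel wrapper, and the model loader above amount to a small pipeline: load a spandrel checkpoint, convert a PIL image to an (N, C, H, W) tensor in [0, 1], run the model, and convert back. Below is a minimal standalone sketch of that flow using spandrel directly rather than the InvokeAI model manager; the checkpoint path, input file, and device choice are illustrative assumptions, and only APIs exercised in this diff are used.

```python
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader

# Hypothetical paths -- substitute a real spandrel checkpoint and input image.
model_path = Path("models/4x_upscaler.pth")
model = ModelLoader().load_from_file(model_path)
assert isinstance(model, ImageModelDescriptor)  # only image-to-image models are supported
model.to(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

image = Image.open("input.png").convert("RGB")

# PIL (H, W, C) uint8 -> torch (N, C, H, W) float in [0, 1], as in pil_to_tensor().
tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float().unsqueeze(0) / 255
tensor = tensor.to(device=model.device, dtype=model.dtype)

with torch.inference_mode():
    output = model(tensor)  # ImageModelDescriptor is callable; run() is a thin wrapper around this

# torch (N, C, H, W) in [0, 1] -> PIL (H, W, C) uint8, as in tensor_to_pil().
output_np = (output.squeeze(0).permute(1, 2, 0).clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
Image.fromarray(output_np).save("output.png")
```

Within InvokeAI itself, the same steps are split across `context.models.load(...)`, `pil_to_tensor()`, `run()`, and `tensor_to_pil()`.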
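
The loader's dtype handling is worth calling out: spandrel model descriptors advertise their precision support (`supports_half`, `supports_bfloat16`), and the loader falls back to float32 when the configured dtype isn't supported rather than failing. A self-contained sketch of that check, with a hypothetical checkpoint path and preferred dtype:

```python
import torch
from spandrel import ImageModelDescriptor, ModelLoader


def pick_dtype(model: ImageModelDescriptor, preferred: torch.dtype) -> torch.dtype:
    """Mirror the loader's fallback: use the preferred dtype only if the model supports it."""
    if preferred == torch.float16:
        return preferred if model.supports_half else torch.float32
    if preferred == torch.bfloat16:
        return preferred if model.supports_bfloat16 else torch.float32
    return torch.float32  # float32 is always supported


model = ModelLoader().load_from_file("models/4x_upscaler.pth")  # hypothetical path
assert isinstance(model, ImageModelDescriptor)
model.to(dtype=pick_dtype(model, torch.float16))
```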
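
On probing: unlike the other checkpoint types, spandrel models have no single state-dict signature the prefix checks above could key on, so the probe simply asks spandrel to load the file and treats `UnsupportedModelError` as "not a spandrel image-to-image model". That is also why the check is ordered last -- it is the only one that has to load the model from disk. A self-contained sketch of the same idea (the helper name is illustrative, and the `FileNotFoundError` handling is a simplification of the probe's missing-file case):

```python
from pathlib import Path

import spandrel
from spandrel import ImageModelDescriptor, ModelLoader


def looks_like_spandrel_image_to_image(path: Path) -> bool:
    """Illustrative helper: can spandrel load this file as an image-to-image model?"""
    try:
        loaded = ModelLoader().load_from_file(path)
    except (spandrel.UnsupportedModelError, FileNotFoundError):
        return False
    return isinstance(loaded, ImageModelDescriptor)
```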
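
Finally, `calc_size()` defers to `calc_module_size()` on the wrapped `torch.nn.Module` (imported lazily to dodge the circular import noted in the HACK comment). The diff does not show `calc_module_size()` itself; the sketch below is an assumption about what such a size estimate typically does -- counting parameter and buffer bytes -- not a copy of the real helper:

```python
import torch


def approx_module_size_bytes(module: torch.nn.Module) -> int:
    # Parameters plus registered buffers; assumed to approximate what calc_module_size() reports.
    tensors = list(module.parameters()) + list(module.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)
```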