@ModelLoaderRegistry.register(
    base=BaseModelType.Any,
    type=ModelType.SpandrelImageToImage,
    format=ModelFormat.Checkpoint,
)
class SpandrelImageToImageModelLoader(ModelLoader):
    """Loader for Spandrel image-to-image checkpoints.

    Registered for ``ModelType.SpandrelImageToImage`` in checkpoint format with
    ``BaseModelType.Any``. The loaded object is a ``SpandrelImageToImageModel``,
    i.e. a wrapper around a ``spandrel.ImageModelDescriptor``.
    """

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load the model file referenced by ``config.path``.

        Args:
            config: Model configuration; only its ``path`` attribute is read here.
            submodel_type: Must be ``None``; this loader does not handle submodels.

        Returns:
            The ``SpandrelImageToImageModel`` built from the file at ``config.path``.

        Raises:
            ValueError: If a submodel is requested.
        """
        if submodel_type is None:
            return SpandrelImageToImageModel.load_from_file(Path(config.path))

        # Any non-None submodel request is rejected.
        raise ValueError("Unexpected submodel requested for Spandrel model.")
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
    """Checkpoint probe class for Spandrel image-to-image models.

    Base-model detection is not implemented for this probe.
    """

    def get_base_type(self) -> BaseModelType:
        """Always raise ``NotImplementedError``; no base type is determined here."""
        raise NotImplementedError
class SpandrelImageToImageModel(RawModel):
    """A wrapper for a Spandrel Image-to-Image model.

    The main reason for having a wrapper class is to integrate with the type handling of RawModel.
    """

    def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
        """Wrap an already-loaded spandrel image model descriptor."""
        self._spandrel_model = spandrel_model

    @classmethod
    def _from_spandrel_model(cls, model: Any) -> "SpandrelImageToImageModel":
        """Validate a model returned by the spandrel loader and wrap it.

        Shared by both public constructors so that the type check and its error
        message live in exactly one place.

        Raises:
            ValueError: If ``model`` is not an ``ImageModelDescriptor``.
        """
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )

        return cls(spandrel_model=model)

    @classmethod
    def load_from_file(cls, file_path: str | Path) -> "SpandrelImageToImageModel":
        """Load a model from a file on disk via ``spandrel.ModelLoader``.

        Args:
            file_path: Location of the model file.

        Returns:
            A new wrapper around the loaded descriptor.

        Raises:
            ValueError: If the loaded model is not an ``ImageModelDescriptor``.
        """
        return cls._from_spandrel_model(ModelLoader().load_from_file(file_path))

    @classmethod
    def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]) -> "SpandrelImageToImageModel":
        """Build a model from an in-memory state dict via ``spandrel.ModelLoader``.

        Args:
            state_dict: The state dict handed to the spandrel loader.

        Returns:
            A new wrapper around the loaded descriptor.

        Raises:
            ValueError: If the loaded model is not an ``ImageModelDescriptor``.
        """
        return cls._from_spandrel_model(ModelLoader().load_from_state_dict(state_dict))

    def supports_dtype(self, dtype: torch.dtype) -> bool:
        """Check if the model supports the given dtype.

        Args:
            dtype: One of ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.

        Returns:
            The descriptor's ``supports_half`` / ``supports_bfloat16`` flag for the
            half-precision dtypes, and ``True`` for ``torch.float32``.

        Raises:
            ValueError: If ``dtype`` is any other dtype.
        """
        if dtype == torch.float16:
            return self._spandrel_model.supports_half
        elif dtype == torch.bfloat16:
            return self._spandrel_model.supports_bfloat16
        elif dtype == torch.float32:
            # All models support float32.
            return True
        else:
            raise ValueError(f"Unexpected dtype '{dtype}'.")

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        """Move and/or cast the wrapped model.

        Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
        Note: The non_blocking parameter is currently ignored.
        """
        # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
        # directly if we want to apply this optimization.
        self._spandrel_model.to(device=device, dtype=dtype)