WIP - Begin to integrate SpandreImageToImageModel type into the model manager.

This commit is contained in:
Ryan Dick
2024-06-28 15:01:42 -04:00
parent e6abea7bc5
commit 59ce9cf41c
5 changed files with 123 additions and 13 deletions

View File

@ -68,6 +68,7 @@ class ModelType(str, Enum):
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
SpandrelImageToImage = "spandrel_image_to_image"
class SubModelType(str, Enum):

View File

@ -0,0 +1,34 @@
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@ModelLoaderRegistry.register(
base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
)
class SpandrelImageToImageModelLoader(ModelLoader):
"""Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for Spandrel model.")
model_path = Path(config.path)
model = SpandrelImageToImageModel.load_from_file(model_path)
return model

View File

@ -10,6 +10,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
from .config import (
@ -240,6 +241,14 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
# Check if the model can be loaded as a SpandrelImageToImageModel.
try:
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
return ModelType.SpandrelImageToImage
except Exception:
# TODO(ryand): Catch a more specific exception type here if we can.
pass
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@classmethod
@ -570,6 +579,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################