WIP - Begin to integrate SpandrelImageToImageModel type into the model manager.

Ryan Dick 2024-06-28 15:01:42 -04:00
parent e6abea7bc5
commit 59ce9cf41c
5 changed files with 123 additions and 13 deletions


@@ -68,6 +68,7 @@ class ModelType(str, Enum):
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    T2IAdapter = "t2i_adapter"
    SpandrelImageToImage = "spandrel_image_to_image"


class SubModelType(str, Enum):
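For reference, the new member behaves like the existing string-valued tags; a quick sanity check (using the import path shown later in this diff):

from invokeai.backend.model_manager.config import ModelType

# ModelType is a str-valued Enum, so the new tag round-trips through its string form.
assert ModelType.SpandrelImageToImage.value == "spandrel_image_to_image"
assert ModelType("spandrel_image_to_image") is ModelType.SpandrelImageToImage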


@@ -0,0 +1,34 @@
from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager.config import (
    AnyModel,
    AnyModelConfig,
    BaseModelType,
    ModelFormat,
    ModelType,
    SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel


@ModelLoaderRegistry.register(
    base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
)
class SpandrelImageToImageModelLoader(ModelLoader):
    """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        if submodel_type is not None:
            raise ValueError("Unexpected submodel requested for Spandrel model.")

        model_path = Path(config.path)
        model = SpandrelImageToImageModel.load_from_file(model_path)

        return model
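For context, the decorator above keys the loader on (base, type, format) so the model manager can select this class for spandrel checkpoints. A deliberately simplified, hypothetical sketch of that registry pattern (not the real ModelLoaderRegistry implementation):

from typing import Callable, Type

_REGISTRY: dict[tuple[str, str, str], Type] = {}


def register(base: str, type: str, format: str) -> Callable[[Type], Type]:
    """Class decorator that records which loader class handles a given (base, type, format) key."""

    def decorator(loader_cls: Type) -> Type:
        _REGISTRY[(base, type, format)] = loader_cls
        return loader_cls

    return decorator


def get_loader(base: str, type: str, format: str) -> Type:
    """Look up the loader class registered for a given key."""
    return _REGISTRY[(base, type, format)]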


@@ -10,6 +10,7 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings

from .config import (
@@ -240,6 +241,14 @@ class ModelProbe(object):
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            return ModelType.TextualInversion

        # Check if the model can be loaded as a SpandrelImageToImageModel.
        try:
            _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
            return ModelType.SpandrelImageToImage
        except Exception:
            # TODO(ryand): Catch a more specific exception type here if we can.
            pass

        raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")

    @classmethod
@@ -570,6 +579,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
        raise NotImplementedError()


class SpandrelImageToImageModelProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################
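The probe change above identifies spandrel checkpoints by simply attempting to load them. A standalone sketch of that heuristic, using only spandrel's public loader API (the helper name is hypothetical):

import torch
from spandrel import ImageModelDescriptor, ModelLoader


def looks_like_spandrel_image_to_image(state_dict: dict[str, torch.Tensor]) -> bool:
    """Return True if spandrel can build an ImageModelDescriptor from this state dict."""
    try:
        return isinstance(ModelLoader().load_from_state_dict(state_dict), ImageModelDescriptor)
    except Exception:
        # spandrel raises when no supported architecture matches the keys in the state dict.
        return False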


@@ -1,6 +1,13 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
class RawModel(ABC):
"""Base class for 'Raw' models.
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid a circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
@@ -10,15 +17,6 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
from abc import ABC, abstractmethod
from typing import Optional
import torch
class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""
@abstractmethod
def to(
self,


@@ -0,0 +1,63 @@
from pathlib import Path
from typing import Any, Optional

import torch
from spandrel import ImageModelDescriptor, ModelLoader

from invokeai.backend.raw_model import RawModel


class SpandrelImageToImageModel(RawModel):
    """A wrapper for a Spandrel Image-to-Image model.

    The main reason for having a wrapper class is to integrate with the type handling of RawModel.
    """

    def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
        self._spandrel_model = spandrel_model

    @classmethod
    def load_from_file(cls, file_path: str | Path):
        model = ModelLoader().load_from_file(file_path)
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )

        return cls(spandrel_model=model)

    @classmethod
    def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
        model = ModelLoader().load_from_state_dict(state_dict)
        if not isinstance(model, ImageModelDescriptor):
            raise ValueError(
                f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
                "('ImageModelDescriptor')."
            )

        return cls(spandrel_model=model)

    def supports_dtype(self, dtype: torch.dtype) -> bool:
        """Check if the model supports the given dtype."""
        if dtype == torch.float16:
            return self._spandrel_model.supports_half
        elif dtype == torch.bfloat16:
            return self._spandrel_model.supports_bfloat16
        elif dtype == torch.float32:
            # All models support float32.
            return True
        else:
            raise ValueError(f"Unexpected dtype '{dtype}'.")

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> None:
        """Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.

        Note: The non_blocking parameter is currently ignored."""
        # TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
        # directly if we want to apply this optimization.
        self._spandrel_model.to(device=device, dtype=dtype)
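Finally, a hedged usage sketch of the new wrapper (not part of this commit): the checkpoint path is a placeholder, and because the wrapper does not yet expose an inference method, the example reaches into the private _spandrel_model descriptor, which spandrel makes callable on BCHW image tensors in [0, 1].

import torch

from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel

model = SpandrelImageToImageModel.load_from_file("path/to/esrgan_checkpoint.pth")  # placeholder path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" and model.supports_dtype(torch.float16) else torch.float32
model.to(device=device, dtype=dtype)

image = torch.rand(1, 3, 64, 64, device=device, dtype=dtype)  # stand-in for a real input image
with torch.no_grad():
    # ImageModelDescriptor instances are callable; the WIP wrapper does not wrap this call yet.
    upscaled = model._spandrel_model(image)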