mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP - Begin to integrate SpandreImageToImageModel type into the model manager.
This commit is contained in:
parent
e6abea7bc5
commit
59ce9cf41c
@ -68,6 +68,7 @@ class ModelType(str, Enum):
|
|||||||
IPAdapter = "ip_adapter"
|
IPAdapter = "ip_adapter"
|
||||||
CLIPVision = "clip_vision"
|
CLIPVision = "clip_vision"
|
||||||
T2IAdapter = "t2i_adapter"
|
T2IAdapter = "t2i_adapter"
|
||||||
|
SpandrelImageToImage = "spandrel_image_to_image"
|
||||||
|
|
||||||
|
|
||||||
class SubModelType(str, Enum):
|
class SubModelType(str, Enum):
|
||||||
|
@ -0,0 +1,34 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from invokeai.backend.model_manager.config import (
|
||||||
|
AnyModel,
|
||||||
|
AnyModelConfig,
|
||||||
|
BaseModelType,
|
||||||
|
ModelFormat,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
)
|
||||||
|
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||||
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||||
|
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||||
|
|
||||||
|
|
||||||
|
@ModelLoaderRegistry.register(
|
||||||
|
base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
|
||||||
|
)
|
||||||
|
class SpandrelImageToImageModelLoader(ModelLoader):
|
||||||
|
"""Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
|
||||||
|
|
||||||
|
def _load_model(
|
||||||
|
self,
|
||||||
|
config: AnyModelConfig,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
|
) -> AnyModel:
|
||||||
|
if submodel_type is not None:
|
||||||
|
raise ValueError("Unexpected submodel requested for Spandrel model.")
|
||||||
|
|
||||||
|
model_path = Path(config.path)
|
||||||
|
model = SpandrelImageToImageModel.load_from_file(model_path)
|
||||||
|
|
||||||
|
return model
|
@ -10,6 +10,7 @@ from picklescan.scanner import scan_file_path
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
|
||||||
|
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
from .config import (
|
from .config import (
|
||||||
@ -240,6 +241,14 @@ class ModelProbe(object):
|
|||||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
|
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||||
|
try:
|
||||||
|
_ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
|
||||||
|
return ModelType.SpandrelImageToImage
|
||||||
|
except Exception:
|
||||||
|
# TODO(ryand): Catch a more specific exception type here if we can.
|
||||||
|
pass
|
||||||
|
|
||||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -570,6 +579,11 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class SpandrelImageToImageModelProbe(CheckpointProbeBase):
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# classes for probing folders
|
# classes for probing folders
|
||||||
#######################################################
|
#######################################################
|
||||||
|
@ -1,15 +1,3 @@
|
|||||||
"""Base class for 'Raw' models.
|
|
||||||
|
|
||||||
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
|
|
||||||
and is used for type checking of calls to the model patcher. Its main purpose
|
|
||||||
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
|
||||||
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
|
||||||
from lora.py.
|
|
||||||
|
|
||||||
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
|
||||||
that adds additional methods and attributes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -17,7 +5,17 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class RawModel(ABC):
|
class RawModel(ABC):
|
||||||
"""Abstract base class for 'Raw' model wrappers."""
|
"""Base class for 'Raw' models.
|
||||||
|
|
||||||
|
The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
|
||||||
|
and is used for type checking of calls to the model patcher. Its main purpose
|
||||||
|
is to avoid a circular import issues when lora.py tries to import BaseModelType
|
||||||
|
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
|
||||||
|
from lora.py.
|
||||||
|
|
||||||
|
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
||||||
|
that adds additional methods and attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def to(
|
def to(
|
||||||
|
63
invokeai/backend/spandrel_image_to_image_model.py
Normal file
63
invokeai/backend/spandrel_image_to_image_model.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from spandrel import ImageModelDescriptor, ModelLoader
|
||||||
|
|
||||||
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
|
|
||||||
|
class SpandrelImageToImageModel(RawModel):
|
||||||
|
"""A wrapper for a Spandrel Image-to-Image model.
|
||||||
|
|
||||||
|
The main reason for having a wrapper class is to integrate with the type handling of RawModel.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
||||||
|
self._spandrel_model = spandrel_model
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_from_file(cls, file_path: str | Path):
|
||||||
|
model = ModelLoader().load_from_file(file_path)
|
||||||
|
if not isinstance(model, ImageModelDescriptor):
|
||||||
|
raise ValueError(
|
||||||
|
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||||
|
"('ImageModelDescriptor')."
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(spandrel_model=model)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||||
|
model = ModelLoader().load_from_state_dict(state_dict)
|
||||||
|
if not isinstance(model, ImageModelDescriptor):
|
||||||
|
raise ValueError(
|
||||||
|
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
|
||||||
|
"('ImageModelDescriptor')."
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(spandrel_model=model)
|
||||||
|
|
||||||
|
def supports_dtype(self, dtype: torch.dtype) -> bool:
|
||||||
|
"""Check if the model supports the given dtype."""
|
||||||
|
if dtype == torch.float16:
|
||||||
|
return self._spandrel_model.supports_half
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
return self._spandrel_model.supports_bfloat16
|
||||||
|
elif dtype == torch.float32:
|
||||||
|
# All models support float32.
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unexpected dtype '{dtype}'.")
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
|
||||||
|
Note: The non_blocking parameter is currently ignored."""
|
||||||
|
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
|
||||||
|
# directly if we want to apply this optimization.
|
||||||
|
self._spandrel_model.to(device=device, dtype=dtype)
|
Loading…
Reference in New Issue
Block a user