final tidying before marking PR as ready for review

- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Update tests and documentation

resolve conflict with seamless.py
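For context, the call-site change looks roughly like this (a hedged sketch assembled from the old load_base docstring and the new registry docstring, not verbatim application code):

# Before: an AnyModelLoader instance dispatched to the right loader internally.
loader = AnyModelLoader(app_config=app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
loaded_model = loader.load_model(model_config, submodel_type)

# After: callers ask the class-level registry for the loader class and instantiate it themselves.
implementation, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type)
loaded_model = implementation(
    app_config=app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache
).load_model(model_config, submodel_type)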
psychedelicious
2024-02-18 17:27:42 +11:00
parent 2ad0752582
commit be8b99eed5
74 changed files with 672 additions and 10362 deletions

View File

@@ -6,12 +6,22 @@ from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel
from .load_base import LoadedModel, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
__all__ = [
"LoadedModel",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
"ModelLoader",
"ModelLoaderRegistryBase",
"ModelLoaderRegistry",
]
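For illustration, a new module dropped into model_loaders/ is imported by the glob loop above and registers itself as a side effect of that import. A minimal sketch; FooLoader and the chosen (base, type, format) key are hypothetical, and registering an already-taken key raises:

# Hypothetical model_loaders/foo.py -- picked up automatically by the import loop above.
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry

@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Checkpoint)
class FooLoader(ModelLoader):
    """Illustrative only; a real loader overrides _load_model()."""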

View File

@@ -1,37 +1,22 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""
Base class for model loading in InvokeAI.
Use like this:
loader = AnyModelLoader(...)
loaded_model = loader.get_model('019ab39adfa1840455')
with loaded_model as model: # context manager moves model into VRAM
# do something with model
"""
import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type
from typing import Any, Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger
@dataclass
@@ -56,6 +41,14 @@ class LoadedModel:
return self._locker.model
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
#
# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to
# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented.
class ModelLoaderBase(ABC):
"""Abstract base class for loading models into RAM/VRAM."""
@@ -71,7 +64,7 @@ class ModelLoaderBase(ABC):
pass
@abstractmethod
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
@@ -90,106 +83,3 @@ class ModelLoaderBase(ABC):
) -> int:
"""Return size in bytes of the model, calculated before loading."""
pass
# TO DO: Better name?
class AnyModelLoader:
"""This class manages the model loaders and invokes the correct one to load a model of given base and type."""
# this tracks the loader subclasses
_registry: Dict[str, Type[ModelLoaderBase]] = {}
_logger: Logger = InvokeAILogger.get_logger()
def __init__(
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize AnyModelLoader with its dependencies."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache associated used by the loaders."""
return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated used by the loaders."""
return self._convert_cache
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
:param model_config: the model's configuration record, as returned by the config backend
:param submodel_type: a SubModelType enum indicating the portion of
the model to retrieve (e.g. SubModelType.Vae)
"""
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
return implementation(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
@staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
return "-".join([base.value, type.value, format.value])
@classmethod
def get_implementation(
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
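To make the TODO(MM2) note above concrete: get_implementation() is annotated as returning Type[ModelLoaderBase], so methods defined only on intermediate subclasses are invisible to a static checker. A hedged illustration, not code from this commit:

# The checker types `loader` as ModelLoaderBase, which does not declare get_hf_load_class().
loader_cls, config, submodel_type = ModelLoaderRegistry.get_implementation(config, submodel_type)
loader = loader_cls(app_config=app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
load_class = loader.get_hf_load_class(model_path)  # flagged statically; fine at runtime for diffusers loaders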

View File

@@ -1,13 +1,9 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Default implementation of model loading in InvokeAI."""
import sys
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from typing import Optional, Tuple
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
@@ -25,17 +21,6 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
return super().load_config(*args, **kwargs) # type: ignore
# TO DO: The loader is not thread safe!
class ModelLoader(ModelLoaderBase):
"""Default implementation of ModelLoaderBase."""
@@ -137,43 +122,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
)
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines
result: ModelMixin = getattr(res_type, class_name)
return result
# TO DO: Add exception handling
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
module, class_name = config[submodel_type.value]
return self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException(
f'The "{submodel_type}" submodel is not available for this model.'
) from e
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")[0]
return self._hf_definition_to_type(module="transformers", class_name=class_name)
if not class_name:
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
raise NotImplementedError

View File

@@ -55,7 +55,7 @@ class MemorySnapshot:
vram = None
try:
malloc_info = LibcUtil().mallinfo2() # type: ignore
malloc_info = LibcUtil().mallinfo2()
except (OSError, AttributeError):
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
"""
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
Use like this:
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
import hashlib
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional, Tuple, Type
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
VaeCheckpointConfig,
VaeDiffusersConfig,
)
from . import ModelLoaderBase
class ModelLoaderRegistryBase(ABC):
"""This class allows model loaders to register their type, base and format."""
@classmethod
@abstractmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
@classmethod
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
Parameters:
:param config: Model configuration record, as returned by ModelRecordService
:param submodel_type: Submodel to fetch (main models only)
:return: tuple(loader_class, model_config, submodel_type)
Note that the returned model config may be different from the one passed in if a submodel type is provided.
"""
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@staticmethod
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
return "-".join([base.value, type.value, format.value])

View File

@@ -13,13 +13,13 @@ from invokeai.backend.model_manager import (
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""

View File

@@ -1,24 +1,27 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for simple diffusers model loading in InvokeAI."""
import sys
from pathlib import Path
from typing import Optional
from typing import Any, Dict, Optional
from diffusers import ConfigMixin, ModelMixin
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
InvalidModelConfigException,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from ..load_base import AnyModelLoader
from ..load_default import ModelLoader
from .. import ModelLoader, ModelLoaderRegistry
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models."""
@@ -28,9 +31,60 @@ class GenericDiffusersLoader(ModelLoader):
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
model_class = self._get_hf_load_class(model_path)
model_class = self.get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
return result
# TO DO: Add exception handling
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
if submodel_type:
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
module, class_name = config[submodel_type.value]
result = self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException(
f'The "{submodel_type}" submodel is not available for this model.'
) from e
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")
assert class_name is not None
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name:
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
return result
# TO DO: Add exception handling
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
if module in ["diffusers", "transformers"]:
res_type = sys.modules[module]
else:
res_type = sys.modules["diffusers"].pipelines
result: ModelMixin = getattr(res_type, class_name)
return result
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""
@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
"""Load a diffusrs ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
# Diffusers doesn't provide typing info
return super().load_config(*args, **kwargs) # type: ignore
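For context on get_hf_load_class(): a diffusers model_index.json maps each submodel name to a [module, class_name] pair, which _hf_definition_to_type() resolves to a class. A hedged sketch with typical, abbreviated values rather than any model's exact contents:

# Abbreviated model_index.json for a main pipeline, shown as a Python dict:
config = {
    "_class_name": "StableDiffusionPipeline",
    "unet": ["diffusers", "UNet2DConditionModel"],
    "text_encoder": ["transformers", "CLIPTextModel"],
}
module, class_name = config["unet"]  # the submodel branch above does this for a unet submodel
# _hf_definition_to_type() then returns getattr(sys.modules[module], class_name)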

View File

@@ -15,11 +15,10 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models."""

View File

@@ -18,13 +18,13 @@ from invokeai.backend.model_manager import (
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from .. import ModelLoader, ModelLoaderRegistry
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader):
"""Class to load LoRA models."""

View File

@@ -13,13 +13,14 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(ModelLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
class OnnyxDiffusersModel(GenericDiffusersLoader):
"""Class to load onnx models."""
def _load_model(
@@ -30,7 +31,7 @@ class OnnyxDiffusersModel(ModelLoader):
) -> AnyModel:
if submodel_type is None:
raise Exception("A submodel type must be provided when loading onnx pipelines.")
load_class = self._get_hf_load_class(model_path, submodel_type)
load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained(

View File

@@ -19,13 +19,14 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(ModelLoader):
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
@@ -43,7 +44,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
) -> AnyModel:
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
load_class = self._get_hf_load_class(model_path, submodel_type)
load_class = self.get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained(

View File

@@ -5,7 +5,6 @@
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -15,12 +14,15 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from .. import ModelLoader, ModelLoaderRegistry
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
@ModelLoaderRegistry.register(
base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder
)
class TextualInversionLoader(ModelLoader):
"""Class to load TI models."""

View File

@@ -14,14 +14,14 @@ from invokeai.backend.model_manager import (
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models."""

View File

@@ -1,16 +1,16 @@
from contextlib import contextmanager
from typing import Any, Generator
import torch
def _no_op(*args, **kwargs):
def _no_op(*args: Any, **kwargs: Any) -> None:
pass
@contextmanager
def skip_torch_weight_init():
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
to skip weight initialization.
def skip_torch_weight_init() -> Generator[None, None, None]:
"""Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization.
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
@@ -18,13 +18,14 @@ def skip_torch_weight_init():
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
saved_functions = [m.reset_parameters for m in torch_modules]
saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
try:
for torch_module in torch_modules:
assert hasattr(torch_module, "reset_parameters")
torch_module.reset_parameters = _no_op
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
assert hasattr(torch_module, "reset_parameters")
torch_module.reset_parameters = saved_function
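A hedged usage sketch of the context manager above; load_path is hypothetical, and skipping initialization is harmless because the checkpoint's weights overwrite the uninitialized ones:

import torch
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init

with skip_torch_weight_init():
    layer = torch.nn.Linear(4096, 4096)  # allocated without running reset_parameters()
layer.load_state_dict(torch.load(load_path))  # real weights come from the checkpoint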