mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry - Fix type check errors in multiple files - Remove apparently unneeded `get_model_config_enum()` method from model manager - Remove last vestiges of old model manager - Updated tests and documentation resolve conflict with seamless.py
This commit is contained in:
@ -1,5 +1,4 @@
|
||||
"""Re-export frequently-used symbols from the Model Manager backend."""
|
||||
|
||||
from .config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -33,3 +32,42 @@ __all__ = [
|
||||
"SchedulerPredictionType",
|
||||
"SubModelType",
|
||||
]
|
||||
|
||||
########## to help populate the openapi_schema with format enums for each config ###########
|
||||
# This code is no longer necessary?
|
||||
# leave it here just in case
|
||||
#
|
||||
# import inspect
|
||||
# from enum import Enum
|
||||
# from typing import Any, Iterable, Dict, get_args, Set
|
||||
# def _expand(something: Any) -> Iterable[type]:
|
||||
# if isinstance(something, type):
|
||||
# yield something
|
||||
# else:
|
||||
# for x in get_args(something):
|
||||
# for y in _expand(x):
|
||||
# yield y
|
||||
|
||||
# def _find_format(cls: type) -> Iterable[Enum]:
|
||||
# if hasattr(inspect, "get_annotations"):
|
||||
# fields = inspect.get_annotations(cls)
|
||||
# else:
|
||||
# fields = cls.__annotations__
|
||||
# if "format" in fields:
|
||||
# for x in get_args(fields["format"]):
|
||||
# yield x
|
||||
# for parent_class in cls.__bases__:
|
||||
# for x in _find_format(parent_class):
|
||||
# yield x
|
||||
# return None
|
||||
|
||||
# def get_model_config_formats() -> Dict[str, Set[Enum]]:
|
||||
# result: Dict[str, Set[Enum]] = {}
|
||||
# for model_config in _expand(AnyModelConfig):
|
||||
# for field in _find_format(model_config):
|
||||
# if field is None:
|
||||
# continue
|
||||
# if not result.get(model_config.__qualname__):
|
||||
# result[model_config.__qualname__] = set()
|
||||
# result[model_config.__qualname__].add(field)
|
||||
# return result
|
||||
|
@ -6,12 +6,22 @@ from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import AnyModelLoader, LoadedModel
|
||||
from .load_base import LoadedModel, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||
for module in loaders:
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
"ModelLoader",
|
||||
"ModelLoaderRegistryBase",
|
||||
"ModelLoaderRegistry",
|
||||
]
|
||||
|
@ -1,37 +1,22 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""
|
||||
Base class for model loading in InvokeAI.
|
||||
|
||||
Use like this:
|
||||
|
||||
loader = AnyModelLoader(...)
|
||||
loaded_model = loader.get_model('019ab39adfa1840455')
|
||||
with loaded_model as model: # context manager moves model into VRAM
|
||||
# do something with loaded_model
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type
|
||||
from typing import Any, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -56,6 +41,14 @@ class LoadedModel:
|
||||
return self._locker.model
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
||||
# know about. I think the problem may be related to this class being an ABC.
|
||||
#
|
||||
# For example, GenericDiffusersLoader defines `get_hf_load_class()`, and StableDiffusionDiffusersModel attempts to
|
||||
# call it. However, the method is not defined in the ABC, so it is not guaranteed to be implemented.
|
||||
|
||||
|
||||
class ModelLoaderBase(ABC):
|
||||
"""Abstract base class for loading models into RAM/VRAM."""
|
||||
|
||||
@ -71,7 +64,7 @@ class ModelLoaderBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its confguration.
|
||||
|
||||
@ -90,106 +83,3 @@ class ModelLoaderBase(ABC):
|
||||
) -> int:
|
||||
"""Return size in bytes of the model, calculated before loading."""
|
||||
pass
|
||||
|
||||
|
||||
# TO DO: Better name?
|
||||
class AnyModelLoader:
|
||||
"""This class manages the model loaders and invokes the correct one to load a model of given base and type."""
|
||||
|
||||
# this tracks the loader subclasses
|
||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||
_logger: Logger = InvokeAILogger.get_logger()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
logger: Logger,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
convert_cache: ModelConvertCacheBase,
|
||||
):
|
||||
"""Initialize AnyModelLoader with its dependencies."""
|
||||
self._app_config = app_config
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache associated used by the loaders."""
|
||||
return self._ram_cache
|
||||
|
||||
@property
|
||||
def convert_cache(self) -> ModelConvertCacheBase:
|
||||
"""Return the convert cache associated used by the loaders."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its configuration.
|
||||
|
||||
:param key: model key, as known to the config backend
|
||||
:param submodel_type: an ModelType enum indicating the portion of
|
||||
the model to retrieve (e.g. ModelType.Vae)
|
||||
"""
|
||||
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
|
||||
return implementation(
|
||||
app_config=self._app_config,
|
||||
logger=self._logger,
|
||||
ram_cache=self._ram_cache,
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
@staticmethod
|
||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||
return "-".join([base.value, type.value, format.value])
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
||||
)
|
||||
cls._registry[key] = subclass
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
@ -1,13 +1,9 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Default implementation of model loading in InvokeAI."""
|
||||
|
||||
import sys
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from diffusers import ModelMixin
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import (
|
||||
@ -25,17 +21,6 @@ from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
|
||||
|
||||
class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
||||
|
||||
# TO DO: The loader is not thread safe!
|
||||
class ModelLoader(ModelLoaderBase):
|
||||
"""Default implementation of ModelLoaderBase."""
|
||||
@ -137,43 +122,6 @@ class ModelLoader(ModelLoaderBase):
|
||||
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
|
||||
)
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||
if module in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
module, class_name = config[submodel_type.value]
|
||||
return self._hf_definition_to_type(module=module, class_name=class_name)
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException(
|
||||
f'The "{submodel_type}" submodel is not available for this model.'
|
||||
) from e
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")[0]
|
||||
return self._hf_definition_to_type(module="transformers", class_name=class_name)
|
||||
if not class_name:
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||
raise NotImplementedError
|
||||
|
@ -55,7 +55,7 @@ class MemorySnapshot:
|
||||
vram = None
|
||||
|
||||
try:
|
||||
malloc_info = LibcUtil().mallinfo2() # type: ignore
|
||||
malloc_info = LibcUtil().mallinfo2()
|
||||
except (OSError, AttributeError):
|
||||
# OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
|
||||
# AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)
|
||||
|
122
invokeai/backend/model_manager/load/model_loader_registry.py
Normal file
122
invokeai/backend/model_manager/load/model_loader_registry.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
"""
|
||||
This module implements a system in which model loaders register the
|
||||
type, base and format of models that they know how to load.
|
||||
|
||||
Use like this:
|
||||
|
||||
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model = cls(
|
||||
app_config=app_config,
|
||||
logger=logger,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
VaeCheckpointConfig,
|
||||
VaeDiffusersConfig,
|
||||
)
|
||||
from . import ModelLoaderBase
|
||||
|
||||
|
||||
class ModelLoaderRegistryBase(ABC):
|
||||
"""This class allows model loaders to register their type, base and format."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""
|
||||
Get subclass of ModelLoaderBase registered to handle base and type.
|
||||
|
||||
Parameters:
|
||||
:param config: Model configuration record, as returned by ModelRecordService
|
||||
:param submodel_type: Submodel to fetch (main models only)
|
||||
:return: tuple(loader_class, model_config, submodel_type)
|
||||
|
||||
Note that the returned model config may be different from one what passed
|
||||
in, in the event that a submodel type is provided.
|
||||
"""
|
||||
|
||||
|
||||
class ModelLoaderRegistry:
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
||||
)
|
||||
cls._registry[key] = subclass
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
|
||||
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
)
|
||||
return implementation, conf2, submodel_type
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
|
||||
)
|
||||
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
|
||||
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
|
||||
submodel_type = None
|
||||
else:
|
||||
new_conf = config
|
||||
return new_conf, submodel_type
|
||||
|
||||
@staticmethod
|
||||
def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
|
||||
return "-".join([base.value, type.value, format.value])
|
@ -13,13 +13,13 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
class ControlnetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
|
@ -1,24 +1,27 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for simple diffusers model loading in InvokeAI."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from diffusers import ConfigMixin, ModelMixin
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
|
||||
from ..load_base import AnyModelLoader
|
||||
from ..load_default import ModelLoader
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
|
||||
class GenericDiffusersLoader(ModelLoader):
|
||||
"""Class to load simple diffusers models."""
|
||||
|
||||
@ -28,9 +31,60 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
model_class = self._get_hf_load_class(model_path)
|
||||
model_class = self.get_hf_load_class(model_path)
|
||||
if submodel_type is not None:
|
||||
raise Exception(f"There are no submodels in models of type {model_class}")
|
||||
variant = model_variant.value if model_variant else None
|
||||
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
"""Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load."""
|
||||
if submodel_type:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
module, class_name = config[submodel_type.value]
|
||||
result = self._hf_definition_to_type(module=module, class_name=class_name)
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException(
|
||||
f'The "{submodel_type}" submodel is not available for this model.'
|
||||
) from e
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")
|
||||
assert class_name is not None
|
||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||
if not class_name:
|
||||
raise InvalidModelConfigException("Unable to decifer Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
return result
|
||||
|
||||
# TO DO: Add exception handling
|
||||
def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type
|
||||
if module in ["diffusers", "transformers"]:
|
||||
res_type = sys.modules[module]
|
||||
else:
|
||||
res_type = sys.modules["diffusers"].pipelines
|
||||
result: ModelMixin = getattr(res_type, class_name)
|
||||
return result
|
||||
|
||||
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
|
||||
return ConfigLoader.load_config(model_path, config_name=config_name)
|
||||
|
||||
|
||||
class ConfigLoader(ConfigMixin):
|
||||
"""Subclass of ConfigMixin for loading diffusers configuration files."""
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""Load a diffusrs ConfigMixin configuration."""
|
||||
cls.config_name = kwargs.pop("config_name")
|
||||
# Diffusers doesn't provide typing info
|
||||
return super().load_config(*args, **kwargs) # type: ignore
|
||||
|
@ -15,11 +15,10 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
|
||||
class IPAdapterInvokeAILoader(ModelLoader):
|
||||
"""Class to load IP Adapter diffusers models."""
|
||||
|
||||
|
@ -18,13 +18,13 @@ from invokeai.backend.model_manager import (
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
||||
class LoraLoader(ModelLoader):
|
||||
"""Class to load LoRA models."""
|
||||
|
||||
|
@ -13,13 +13,14 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(ModelLoader):
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
||||
def _load_model(
|
||||
@ -30,7 +31,7 @@ class OnnyxDiffusersModel(ModelLoader):
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||
load_class = self._get_hf_load_class(model_path, submodel_type)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
|
@ -19,13 +19,14 @@ from invokeai.backend.model_manager import (
|
||||
)
|
||||
from invokeai.backend.model_manager.config import MainCheckpointConfig
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class StableDiffusionDiffusersModel(ModelLoader):
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
model_base_to_model_type = {
|
||||
@ -43,7 +44,7 @@ class StableDiffusionDiffusersModel(ModelLoader):
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading main pipelines.")
|
||||
load_class = self._get_hf_load_class(model_path, submodel_type)
|
||||
load_class = self.get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
|
@ -5,7 +5,6 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -15,12 +14,15 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
from .. import ModelLoader, ModelLoaderRegistry
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder
|
||||
)
|
||||
class TextualInversionLoader(ModelLoader):
|
||||
"""Class to load TI models."""
|
||||
|
||||
|
@ -14,14 +14,14 @@ from invokeai.backend.model_manager import (
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
|
||||
from .. import ModelLoaderRegistry
|
||||
from .generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
||||
class VaeLoader(GenericDiffusersLoader):
|
||||
"""Class to load VAE models."""
|
||||
|
||||
|
@ -1,16 +1,16 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _no_op(*args, **kwargs):
|
||||
def _no_op(*args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def skip_torch_weight_init():
|
||||
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
|
||||
to skip weight initialization.
|
||||
def skip_torch_weight_init() -> Generator[None, None, None]:
|
||||
"""Monkey patch several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) to skip weight initialization.
|
||||
|
||||
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
|
||||
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
|
||||
@ -18,13 +18,14 @@ def skip_torch_weight_init():
|
||||
monkey-patches common torch layers to skip the weight initialization step.
|
||||
"""
|
||||
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
|
||||
saved_functions = [m.reset_parameters for m in torch_modules]
|
||||
saved_functions = [hasattr(m, "reset_parameters") and m.reset_parameters for m in torch_modules]
|
||||
|
||||
try:
|
||||
for torch_module in torch_modules:
|
||||
assert hasattr(torch_module, "reset_parameters")
|
||||
torch_module.reset_parameters = _no_op
|
||||
|
||||
yield None
|
||||
finally:
|
||||
for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
|
||||
assert hasattr(torch_module, "reset_parameters")
|
||||
torch_module.reset_parameters = saved_function
|
||||
|
@ -13,7 +13,7 @@ from typing import Any, List, Optional, Set
|
||||
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers import logging as dlogging
|
||||
from diffusers.utils import logging as dlogging
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
|
||||
@ -76,7 +76,7 @@ class ModelMerger(object):
|
||||
custom_pipeline="checkpoint_merger",
|
||||
torch_dtype=dtype,
|
||||
variant=variant,
|
||||
)
|
||||
) # type: ignore
|
||||
merged_pipe = pipe.merge(
|
||||
pretrained_model_name_or_path_list=model_paths,
|
||||
alpha=alpha,
|
||||
|
@ -54,8 +54,8 @@ class LicenseRestrictions(BaseModel):
|
||||
AllowDifferentLicense: bool = Field(
|
||||
description="if true, derivatives of this model be redistributed under a different license", default=False
|
||||
)
|
||||
AllowCommercialUse: CommercialUsage = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default_factory=set
|
||||
AllowCommercialUse: Optional[CommercialUsage] = Field(
|
||||
description="Type of commercial use allowed or 'No' if no commercial use is allowed.", default=None
|
||||
)
|
||||
|
||||
|
||||
@ -139,7 +139,10 @@ class CivitaiMetadata(ModelMetadataWithFiles):
|
||||
@property
|
||||
def allow_commercial_use(self) -> bool:
|
||||
"""Return True if commercial use is allowed."""
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
if self.restrictions.AllowCommercialUse is None:
|
||||
return False
|
||||
else:
|
||||
return self.restrictions.AllowCommercialUse != CommercialUsage("None")
|
||||
|
||||
@property
|
||||
def allow_derivatives(self) -> bool:
|
||||
|
@ -8,7 +8,6 @@ import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
from invokeai.backend.util.util import SilenceWarnings
|
||||
|
||||
from .config import (
|
||||
@ -23,6 +22,7 @@ from .config import (
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
from .hash import FastModelHash
|
||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||
|
||||
CkptType = Dict[str, Any]
|
||||
|
||||
@ -53,6 +53,7 @@ LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[Sched
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ProbeBase(object):
|
||||
"""Base class for probes."""
|
||||
|
||||
|
@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
models_found: Optional[Set[Path]] = Field(default=None)
|
||||
scanned_dirs: Optional[Set[Path]] = Field(default=None)
|
||||
pruned_paths: Optional[Set[Path]] = Field(default=None)
|
||||
models_found: Set[Path] = Field(default_factory=set)
|
||||
scanned_dirs: Set[Path] = Field(default_factory=set)
|
||||
pruned_paths: Set[Path] = Field(default_factory=set)
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
|
@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure):
|
||||
("keepcost", ctypes.c_size_t),
|
||||
]
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
s = ""
|
||||
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
|
||||
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
|
||||
@ -62,7 +62,7 @@ class LibcUtil:
|
||||
TODO: Improve cross-OS compatibility of this class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
|
||||
|
||||
def mallinfo2(self) -> Struct_mallinfo2:
|
||||
@ -72,4 +72,5 @@ class LibcUtil:
|
||||
"""
|
||||
mallinfo2 = self._libc.mallinfo2
|
||||
mallinfo2.restype = Struct_mallinfo2
|
||||
return mallinfo2()
|
||||
result: Struct_mallinfo2 = mallinfo2()
|
||||
return result
|
||||
|
@ -1,12 +1,15 @@
|
||||
"""Utilities for parsing model files, used mostly by probe.py"""
|
||||
|
||||
import json
|
||||
import torch
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
def _fast_safetensors_reader(path: str):
|
||||
|
||||
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
|
||||
checkpoint = {}
|
||||
device = torch.device("meta")
|
||||
with open(path, "rb") as f:
|
||||
@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str):
|
||||
|
||||
return checkpoint
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
|
||||
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
|
||||
if str(path).endswith(".safetensors"):
|
||||
try:
|
||||
checkpoint = _fast_safetensors_reader(path)
|
||||
path_str = path.as_posix() if isinstance(path, Path) else path
|
||||
checkpoint = _fast_safetensors_reader(path_str)
|
||||
except Exception:
|
||||
# TODO: create issue for support "meta"?
|
||||
checkpoint = safetensors.torch.load_file(path, device="cpu")
|
||||
@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
|
||||
checkpoint = torch.load(path, map_location=torch.device("meta"))
|
||||
return checkpoint
|
||||
|
||||
def lora_token_vector_length(checkpoint: dict) -> int:
|
||||
|
||||
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
|
||||
"""
|
||||
Given a checkpoint in memory, return the lora token vector length
|
||||
|
||||
:param checkpoint: The checkpoint
|
||||
"""
|
||||
|
||||
def _get_shape_1(key: str, tensor, checkpoint) -> int:
|
||||
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
|
||||
lora_token_vector_length = None
|
||||
|
||||
if "." not in key:
|
||||
|
Reference in New Issue
Block a user