# Copyright (c) 2023, Lincoln D. Stein """Model loader for InvokeAI.""" import hashlib from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple, Union import torch from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util import InvokeAILogger, Logger, choose_precision, choose_torch_device from .cache import CacheStats, ModelCache from .config import BaseModelType, ModelConfigBase, ModelType, SubModelType from .download import DownloadEventHandler, DownloadQueueBase from .install import ModelInstall, ModelInstallBase from .models import MODEL_CLASSES, InvalidModelException, ModelBase from .storage import ConfigFileVersionMismatchException, ModelConfigStore, get_config_store, migrate_models_store @dataclass class ModelInfo: """This is a context manager object that is used to intermediate access to a model.""" context: ModelCache.ModelLocker name: str base_model: BaseModelType type: Union[ModelType, SubModelType] key: str location: Union[Path, str] precision: torch.dtype _cache: Optional[ModelCache] = None def __enter__(self): """Context entry.""" return self.context.__enter__() def __exit__(self, *args, **kwargs): """Context exit.""" self.context.__exit__(*args, **kwargs) class ModelLoadBase(ABC): """Abstract base class for a model loader which works with the ModelConfigStore backend.""" @abstractmethod def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo: """ Return a model given its key. Given a model key identified in the model configuration backend, return a ModelInfo object that can be used to retrieve the model. :param key: model key, as known to the config backend :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ pass @property @abstractmethod def store(self) -> ModelConfigStore: """Return the ModelConfigStore object that supports this loader.""" pass @property @abstractmethod def installer(self) -> ModelInstallBase: """Return the ModelInstallBase object that supports this loader.""" pass @property @abstractmethod def logger(self) -> Logger: """Return the current logger.""" pass @property def config(self) -> InvokeAIAppConfig: """Return the config object used by this installer.""" pass @property @abstractmethod def queue(self) -> DownloadQueueBase: """Return the download queue object used by this object.""" pass @abstractmethod def collect_cache_stats(self, cache_stats: CacheStats): """Replace cache statistics.""" pass @abstractmethod def resolve_model_path(self, path: Union[Path, str]) -> Path: """Turn a potentially relative path into an absolute one in the models_dir.""" pass @property @abstractmethod def precision(self) -> str: """Return 'float32' or 'float16'.""" pass @abstractmethod def sync_to_config(self): """Reinitialize the store to sync in-memory and in-disk versions.""" pass class ModelLoad(ModelLoadBase): """Implementation of ModelLoadBase.""" _app_config: InvokeAIAppConfig _store: ModelConfigStore _installer: ModelInstallBase _cache: ModelCache _logger: Logger _cache_keys: dict _models_file: Path def __init__(self, config: InvokeAIAppConfig, event_handlers: List[DownloadEventHandler] = []): """ Initialize ModelLoad object. :param config: The app's InvokeAIAppConfig object. """ if config.model_conf_path and config.model_conf_path.exists(): models_file = config.model_conf_path else: models_file = config.root_path / "configs/models3.yaml" try: store = get_config_store(models_file) except ConfigFileVersionMismatchException: migrate_models_store(config) store = get_config_store(models_file) if not store: raise ValueError(f"Invalid model configuration file: {models_file}") self._app_config = config self._store = store self._logger = InvokeAILogger.get_logger() self._installer = ModelInstall( store=self._store, logger=self._logger, config=self._app_config, event_handlers=event_handlers, ) self._cache_keys = dict() self._models_file = models_file device = torch.device(choose_torch_device()) device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else "" precision = choose_precision(device) if config.precision == "auto" else config.precision dtype = torch.float32 if precision == "float32" else torch.float16 self._logger.info(f"Using models database {models_file}") self._logger.info(f"Rendering device = {device} ({device_name})") self._logger.info(f"Maximum RAM cache size: {config.ram_cache_size}") self._logger.info(f"Maximum VRAM cache size: {config.vram_cache_size}") self._logger.info(f"Precision: {precision}") self._cache = ModelCache( max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, lazy_offloading=config.lazy_offload, execution_device=device, precision=dtype, sequential_offload=config.sequential_guidance, logger=self._logger, ) @property def store(self) -> ModelConfigStore: """Return the ModelConfigStore instance used by this class.""" return self._store @property def precision(self) -> str: """Return 'float32' or 'float16'.""" return self._cache.precision @property def installer(self) -> ModelInstallBase: """Return the ModelInstallBase instance used by this class.""" return self._installer @property def logger(self) -> Logger: """Return the current logger.""" return self._logger @property def config(self) -> InvokeAIAppConfig: """Return the config object used by the installer.""" return self._app_config @property def queue(self) -> DownloadQueueBase: """Return the download queue object used by this object.""" return self._installer.queue def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo: """ Get the ModelInfo corresponding to the model with key "key". Given a model key identified in the model configuration backend, return a ModelInfo object that can be used to retrieve the model. :param key: model key, as known to the config backend :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ model_config = self.store.get_model(key) # May raise a UnknownModelException if model_config.model_type == "main" and not submodel_type: raise InvalidModelException("submodel_type is required when loading a main model") submodel_type = SubModelType(submodel_type) if submodel_type else None model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) if is_submodel_override: submodel_type = None model_class = self._get_implementation(model_config.base_model, model_config.model_type) if not model_path.exists(): raise InvalidModelException(f"Files for model '{key}' not found at {model_path}") dst_convert_path = self._get_model_convert_cache_path(model_path) model_path = self.resolve_model_path( model_class.convert_if_required( model_config=model_config, output_path=dst_convert_path, ) ) model_context = self._cache.get_model( model_path=model_path, model_class=model_class, base_model=model_config.base_model, model_type=model_config.model_type, submodel=submodel_type, ) if key not in self._cache_keys: self._cache_keys[key] = set() self._cache_keys[key].add(model_context.key) return ModelInfo( context=model_context, name=model_config.name, base_model=model_config.base_model, type=submodel_type or model_config.model_type, key=model_config.key, location=model_path, precision=self._cache.precision, _cache=self._cache, ) def collect_cache_stats(self, cache_stats: CacheStats): """Save CacheStats object for stats collecting.""" self._cache.stats = cache_stats def resolve_model_path(self, path: Union[Path, str]) -> Path: """Turn a potentially relative path into an absolute one in the models_dir.""" return self._app_config.models_path / path def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: """Get the concrete implementation class for a specific model type.""" model_class = MODEL_CLASSES[base_model][model_type] return model_class def _get_model_convert_cache_path(self, model_path): return self.resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest()) def _get_model_path( self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None ) -> Tuple[Path, bool]: """Extract a model's filesystem path from its config. :return: The fully qualified Path of the module (or submodule). """ model_path = Path(model_config.path) is_submodel_override = False # Does the config explicitly override the submodel? if submodel_type is not None and hasattr(model_config, submodel_type): submodel_path = getattr(model_config, submodel_type) if submodel_path is not None and len(submodel_path) > 0: model_path = getattr(model_config, submodel_type) is_submodel_override = True model_path = self.resolve_model_path(model_path) return model_path, is_submodel_override def sync_to_config(self): """Synchronize models on disk to those in memory.""" self._store = get_config_store(self._models_file) self.installer.scan_models_directory()