Files
InvokeAI/invokeai/backend/model_manager/loader.py
Lincoln Stein 7430d87301 loader working
2023-09-10 23:11:25 -04:00

226 lines
8.2 KiB
Python

# Copyright (c) 2023, Lincoln D. Stein
"""Model loader for InvokeAI."""
import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Union, Optional
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import choose_precision, choose_torch_device, InvokeAILogger
from .config import BaseModelType, ModelType, SubModelType, ModelConfigBase
from .install import ModelInstallBase, ModelInstall
from .storage import ModelConfigStore, ModelConfigStoreYAML, ModelConfigStoreSQL
from .cache import ModelCache, ModelLocker
from .models import InvalidModelException, ModelBase, MODEL_CLASSES
@dataclass
class ModelInfo:
"""This is a context manager object that is used to intermediate access to a model."""
context: ModelLocker
name: str
base_model: BaseModelType
type: ModelType
id: str
location: Union[Path, str]
precision: torch.dtype
_cache: Optional[ModelCache] = None
def __enter__(self):
"""Context entry."""
return self.context.__enter__()
def __exit__(self, *args, **kwargs):
"""Context exit."""
self.context.__exit__(*args, **kwargs)
class ModelLoaderBase(ABC):
"""Abstract base class for a model loader which works with the ModelConfigStore backend."""
@abstractmethod
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Return a model given its key.
Given a model key identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
:param key: model key, as known to the config backend
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
pass
@property
@abstractmethod
def store(self) -> ModelConfigStore:
"""Return the ModelConfigStore object that supports this loader."""
pass
@property
@abstractmethod
def installer(self) -> ModelInstallBase:
"""Return the ModelInstallBase object that supports this loader."""
pass
class ModelLoader(ModelLoaderBase):
"""Implementation of ModelLoaderBase."""
_app_config: InvokeAIAppConfig
_store: ModelConfigStore
_installer: ModelInstallBase
_cache: ModelCache
_logger: InvokeAILogger
_cache_keys: dict
def __init__(
self,
config: InvokeAIAppConfig,
):
"""
Initialize ModelLoader object.
:param config: The app's InvokeAIAppConfig object.
"""
if config.model_conf_path and config.model_conf_path.exists():
models_file = config.model_conf_path
else:
models_file = config.root_path / "configs/models3.yaml"
store = (
ModelConfigStoreYAML(models_file)
if models_file.suffix == ".yaml"
else ModelConfigStoreSQL(models_file)
if models_file.suffix == ".db"
else None
)
if not store:
raise ValueError(f"Invalid model configuration file: {models_file}")
self._app_config = config
self._store = store
self._logger = InvokeAILogger.getLogger()
self._installer = ModelInstall(store=self._store, logger=self._logger, config=self._app_config)
self._cache_keys = dict()
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
precision = choose_precision(device) if config.precision == "auto" else config.precision
dtype = torch.float32 if precision == "float32" else torch.float16
self._logger.info(f"Using models database {models_file}")
self._logger.info(f"Rendering device = {device} ({device_name})")
self._logger.info(f"Maximum RAM cache size: {config.ram_cache_size}")
self._logger.info(f"Maximum VRAM cache size: {config.vram_cache_size}")
self._logger.info(f"Precision: {precision}")
self._cache = ModelCache(
max_cache_size=config.ram_cache_size,
max_vram_cache_size=config.vram_cache_size,
lazy_offloading=config.lazy_offload,
execution_device=device,
precision=dtype,
sequential_offload=config.sequential_guidance,
logger=self._logger,
)
@property
def store(self) -> ModelConfigStore:
"""Return the ModelConfigStore instance used by this class."""
return self._store
@property
def installer(self) -> ModelInstallBase:
"""Return the ModelInstallBase instance used by this class."""
return self._installer
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> ModelInfo:
"""
Get the ModelInfo corresponding to the model with key "key".
Given a model key identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
:param key: model key, as known to the config backend
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_config = self.store.get_model(key) # May raise a UnknownModelException
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(model_config.base_model, model_config.model_type)
if not model_path.exists():
raise InvalidModelException(f"Files for model '{key}' not found at {model_path}")
dst_convert_path = self._get_model_cache_path(model_path)
model_path = model_class.convert_if_required(
base_model=model_config.base_model,
model_path=model_path.as_posix(),
output_path=dst_convert_path,
config=model_config,
)
model_context = self._cache.get_model(
model_path=model_path,
model_class=model_class,
base_model=model_config.base_model,
model_type=model_config.model_type,
submodel=SubModelType(submodel_type),
)
if key not in self._cache_keys:
self._cache_keys[key] = set()
self._cache_keys[key].add(model_context.key)
return ModelInfo(
context=model_context,
name=model_config.name,
base_model=model_config.base_model,
type=submodel_type or model_type,
id=model_config.id,
location=model_path,
precision=self._cache.precision,
_cache=self._cache,
)
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type]
return model_class
def _get_model_cache_path(self, model_path):
return self._resolve_model_path(Path(".cache") / hashlib.md5(str(model_path).encode()).hexdigest())
def _resolve_model_path(self, path: Union[Path, str]) -> Path:
"""Return relative paths based on configured models_path."""
return self._app_config.models_path / path
def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None and len(submodel_path) > 0:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self._resolve_model_path(model_path)
return model_path, is_submodel_override