2024-02-01 04:37:59 +00:00
|
|
|
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
|
|
|
"""Default implementation of model loading in InvokeAI."""
|
|
|
|
|
|
|
|
from logging import Logger
|
|
|
|
from pathlib import Path
|
2024-03-29 20:11:08 +00:00
|
|
|
from typing import Optional
|
2024-02-01 04:37:59 +00:00
|
|
|
|
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig
|
2024-02-05 04:18:00 +00:00
|
|
|
from invokeai.backend.model_manager import (
|
|
|
|
AnyModel,
|
|
|
|
AnyModelConfig,
|
|
|
|
InvalidModelConfigException,
|
|
|
|
SubModelType,
|
|
|
|
)
|
2024-06-27 21:31:28 +00:00
|
|
|
from invokeai.backend.model_manager.config import DiffusersConfigBase
|
2024-02-04 22:23:10 +00:00
|
|
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
|
2024-02-09 21:42:33 +00:00
|
|
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
2024-04-12 04:55:21 +00:00
|
|
|
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
|
2024-02-04 22:23:10 +00:00
|
|
|
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
|
2024-04-15 13:12:49 +00:00
|
|
|
from invokeai.backend.util.devices import TorchDevice
|
2024-02-01 04:37:59 +00:00
|
|
|
|
|
|
|
|
|
|
|
# TO DO: The loader is not thread safe!
|
|
|
|
class ModelLoader(ModelLoaderBase):
    """Default implementation of ModelLoaderBase."""

    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase[AnyModel],
    ):
        """Initialize the loader with the app configuration, a logger, and the shared RAM cache."""
        self._ram_cache = ram_cache
        self._logger = logger
        self._app_config = app_config
        self._torch_dtype = TorchDevice.choose_torch_dtype()

    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Return a model given its configuration.

        Given a model's configuration as returned by the ModelRecordConfigStore service,
        return a LoadedModel object that can be used for inference.

        :param model_config: Configuration record for this model
        :param submodel_type: a SubModelType enum indicating the portion of
           the model to retrieve (e.g. SubModelType.Vae)
        """
        on_disk = self._get_model_path(model_config)
        if not on_disk.exists():
            raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {on_disk}")

        # Skipping weight initialization speeds up instantiation: weights are
        # about to be overwritten by the checkpoint load anyway.
        with skip_torch_weight_init():
            cache_locker = self._load_and_cache(model_config, submodel_type)
        return LoadedModel(config=model_config, _locker=cache_locker)

    @property
    def ram_cache(self) -> ModelCacheBase[AnyModel]:
        """Return the ram cache associated with this loader."""
        return self._ram_cache

    def _get_model_path(self, config: AnyModelConfig) -> Path:
        # Config records store paths relative to the app's models directory.
        return (self._app_config.models_path / config.path).resolve()

    def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
        # Fast path: the (key, submodel) pair may already be resident in the cache.
        try:
            return self._ram_cache.get(config.key, submodel_type)
        except IndexError:
            # Cache miss — fall through and load from disk.
            pass

        # Rewrite the config path to an absolute one before sizing/loading.
        config.path = str(self._get_model_path(config))
        # Evict as needed so the incoming model fits in the cache budget.
        self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
        model_instance = self._load_model(config, submodel_type)

        self._ram_cache.put(config.key, submodel_type=submodel_type, model=model_instance)

        # Re-fetch through the cache so the caller gets a proper locker and
        # usage statistics are recorded under a human-readable name.
        return self._ram_cache.get(
            key=config.key,
            submodel_type=submodel_type,
            stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
        )

    def get_size_fs(
        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
    ) -> int:
        """Get the size of the model on disk."""
        subfolder = submodel_type.value if submodel_type else None
        variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
        return calc_model_size_by_fs(model_path=model_path, subfolder=subfolder, variant=variant)

    # This needs to be implemented in the subclass
    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        raise NotImplementedError
|