mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Multiple refinements on loaders:
- Cache stat collection enabled. - Implemented ONNX loading. - Add ability to specify the repo version variant in installer CLI. - If caller asks for a repo version that doesn't exist, will fall back to empty version rather than raising an error.
This commit is contained in:
committed by
psychedelicious
parent
ad2926a24c
commit
fbded1c0f2
@ -16,7 +16,6 @@ from .model_cache.model_cache_default import ModelCache
|
||||
# This registers the subclasses that implement loaders of specific model types
|
||||
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
|
||||
for module in loaders:
|
||||
print(f"module={module}")
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = ["AnyModelLoader", "LoadedModel"]
|
||||
|
@ -22,6 +22,7 @@ from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelTy
|
||||
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
|
||||
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -88,6 +89,7 @@ class AnyModelLoader:
|
||||
|
||||
# this tracks the loader subclasses
|
||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||
_logger: Logger = InvokeAILogger.get_logger()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -167,7 +169,7 @@ class AnyModelLoader:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
print("DEBUG: Registering class", subclass.__name__)
|
||||
cls._logger.debug(f"Registering class {subclass.__name__} to load models of type {base}/{type}/{format}")
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
cls._registry[key] = subclass
|
||||
return subclass
|
||||
|
@ -52,7 +52,7 @@ class ModelLoader(ModelLoaderBase):
|
||||
self._logger = logger
|
||||
self._ram_cache = ram_cache
|
||||
self._convert_cache = convert_cache
|
||||
self._torch_dtype = torch_dtype(choose_torch_device())
|
||||
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
@ -102,8 +102,10 @@ class ModelLoader(ModelLoaderBase):
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
) -> ModelLockerBase:
|
||||
# TO DO: This is not thread safe!
|
||||
if self._ram_cache.exists(config.key, submodel_type):
|
||||
try:
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
model_variant = getattr(config, "repo_variant", None)
|
||||
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
|
||||
@ -119,7 +121,11 @@ class ModelLoader(ModelLoaderBase):
|
||||
size=calc_model_size_by_data(loaded_model),
|
||||
)
|
||||
|
||||
return self._ram_cache.get(config.key, submodel_type)
|
||||
return self._ram_cache.get(
|
||||
key=config.key,
|
||||
submodel_type=submodel_type,
|
||||
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
|
||||
)
|
||||
|
||||
def get_size_fs(
|
||||
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
|
||||
@ -146,13 +152,21 @@ class ModelLoader(ModelLoaderBase):
|
||||
# TO DO: Add exception handling
|
||||
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
|
||||
if submodel_type:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
module, class_name = config[submodel_type.value]
|
||||
return self._hf_definition_to_type(module=module, class_name=class_name)
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="model_index.json")
|
||||
module, class_name = config[submodel_type.value]
|
||||
return self._hf_definition_to_type(module=module, class_name=class_name)
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException(
|
||||
f'The "{submodel_type}" submodel is not available for this model.'
|
||||
) from e
|
||||
else:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config["_class_name"]
|
||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config["_class_name"]
|
||||
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
||||
# This needs to be implemented in subclasses that handle checkpoints
|
||||
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
|
||||
|
@ -1,3 +1,4 @@
|
||||
"""Init file for RamCache."""
|
||||
"""Init file for ModelCache."""
|
||||
|
||||
|
||||
_all__ = ["ModelCacheBase", "ModelCache"]
|
||||
|
@ -129,11 +129,17 @@ class ModelCacheBase(ABC, Generic[T]):
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model locker object using key and optional submodel_type.
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
This may return an UnknownModelException if the model is not in the cache.
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@ -24,6 +24,7 @@ import math
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from logging import Logger
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
@ -55,6 +56,20 @@ GIG = 1073741824
|
||||
MB = 2**20
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheStats(object):
|
||||
"""Collect statistics on cache performance."""
|
||||
|
||||
hits: int = 0 # cache hits
|
||||
misses: int = 0 # cache misses
|
||||
high_watermark: int = 0 # amount of cache used
|
||||
in_cache: int = 0 # number of models in cache
|
||||
cleared: int = 0 # number of models cleared to make space
|
||||
cache_size: int = 0 # total size of cache
|
||||
# {submodel_key => size}
|
||||
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelCache(ModelCacheBase[AnyModel]):
|
||||
"""Implementation of ModelCacheBase."""
|
||||
|
||||
@ -94,6 +109,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self._storage_device: torch.device = storage_device
|
||||
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
|
||||
# used for stats collection
|
||||
self.stats = CacheStats()
|
||||
|
||||
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
|
||||
self._cache_stack: List[str] = []
|
||||
@ -158,21 +175,40 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
stats_name: Optional[str] = None,
|
||||
) -> ModelLockerBase:
|
||||
"""
|
||||
Retrieve model using key and optional submodel_type.
|
||||
|
||||
This may return an IndexError if the model is not in the cache.
|
||||
:param key: Opaque model key
|
||||
:param submodel_type: Type of the submodel to fetch
|
||||
:param stats_name: A human-readable id for the model for the purposes of
|
||||
stats reporting.
|
||||
|
||||
This may raise an IndexError if the model is not in the cache.
|
||||
"""
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key not in self._cached_models:
|
||||
if key in self._cached_models:
|
||||
self.stats.hits += 1
|
||||
else:
|
||||
self.stats.misses += 1
|
||||
raise IndexError(f"The model with key {key} is not in the cache.")
|
||||
|
||||
cache_entry = self._cached_models[key]
|
||||
|
||||
# more stats
|
||||
stats_name = stats_name or key
|
||||
self.stats.cache_size = int(self._max_cache_size * GIG)
|
||||
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
|
||||
self.stats.in_cache = len(self._cached_models)
|
||||
self.stats.loaded_model_sizes[stats_name] = max(
|
||||
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
|
||||
)
|
||||
|
||||
# this moves the entry to the top (right end) of the stack
|
||||
with suppress(Exception):
|
||||
self._cache_stack.remove(key)
|
||||
self._cache_stack.append(key)
|
||||
cache_entry = self._cached_models[key]
|
||||
return ModelLocker(
|
||||
cache=self,
|
||||
cache_entry=cache_entry,
|
||||
|
41
invokeai/backend/model_manager/load/model_loaders/onnx.py
Normal file
41
invokeai/backend/model_manager/load/model_loaders/onnx.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Class for Onnx model loading in InvokeAI."""
|
||||
|
||||
# This should work the same as Stable Diffusion pipelines
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
|
||||
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
||||
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||
class OnnyxDiffusersModel(ModelLoader):
|
||||
"""Class to load onnx models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
model_path: Path,
|
||||
model_variant: Optional[ModelRepoVariant] = None,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if not submodel_type is not None:
|
||||
raise Exception("A submodel type must be provided when loading onnx pipelines.")
|
||||
load_class = self._get_hf_load_class(model_path, submodel_type)
|
||||
variant = model_variant.value if model_variant else None
|
||||
model_path = model_path / submodel_type.value
|
||||
result: AnyModel = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
variant=variant,
|
||||
) # type: ignore
|
||||
return result
|
Reference in New Issue
Block a user