Multiple refinements on loaders:

- Enable cache stat collection (see the sketch after this list).
- Implement ONNX loading.
- Add the ability to specify the repo version variant in the installer CLI.
- If the caller asks for a repo version that doesn't exist, fall back to the
  empty version rather than raising an error.
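A quick, illustrative look at what the new stat collection exposes (not part of
the commit): assuming model_cache is a ModelCache instance that has already
served a few requests, the counters wired up in the diff below can be read
directly.

    # hypothetical inspection snippet; model_cache is assumed to exist already
    print(f"hits={model_cache.stats.hits} misses={model_cache.stats.misses}")
    print(f"models currently cached: {model_cache.stats.in_cache}")
    print(f"largest observed sizes: {model_cache.stats.loaded_model_sizes}")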
commit 5745ce9c7d (parent 0d3addc69b)
Author: Lincoln Stein
Date: 2024-02-05 21:55:11 -05:00
Committed by: psychedelicious
18 changed files with 215 additions and 49 deletions


@@ -1,3 +1,4 @@
-"""Init file for RamCache."""
+"""Init file for ModelCache."""
+__all__ = ["ModelCacheBase", "ModelCache"]


@@ -129,11 +129,17 @@ class ModelCacheBase(ABC, Generic[T]):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
+        stats_name: Optional[str] = None,
     ) -> ModelLockerBase:
         """
-        Retrieve model locker object using key and optional submodel_type.
+        Retrieve model using key and optional submodel_type.
-        This may return an UnknownModelException if the model is not in the cache.
+        :param key: Opaque model key
+        :param submodel_type: Type of the submodel to fetch
+        :param stats_name: A human-readable id for the model for the purposes of
+        stats reporting.
+        This may raise an IndexError if the model is not in the cache.
         """
         pass
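For context, a minimal usage sketch of the API documented above (not part of
the diff). It assumes cache is any concrete ModelCacheBase implementation and
model_key is an opaque key already known to it; only the signature and the
IndexError behavior described in the docstring are relied on.

    try:
        # stats_name groups size/hit statistics under a readable label
        locker = cache.get(model_key, stats_name="main-model:text_encoder")
    except IndexError:
        # the model has not been loaded into the cache yet
        locker = None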


@@ -24,6 +24,7 @@ import math
 import sys
 import time
 from contextlib import suppress
+from dataclasses import dataclass, field
 from logging import Logger
 from typing import Dict, List, Optional
@@ -55,6 +56,20 @@ GIG = 1073741824
 MB = 2**20
+@dataclass
+class CacheStats(object):
+    """Collect statistics on cache performance."""
+    hits: int = 0  # cache hits
+    misses: int = 0  # cache misses
+    high_watermark: int = 0  # amount of cache used
+    in_cache: int = 0  # number of models in cache
+    cleared: int = 0  # number of models cleared to make space
+    cache_size: int = 0  # total size of cache
+    # {submodel_key => size}
+    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
 class ModelCache(ModelCacheBase[AnyModel]):
     """Implementation of ModelCacheBase."""
@@ -94,6 +109,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._storage_device: torch.device = storage_device
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
+        # used for stats collection
+        self.stats = CacheStats()
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []
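CacheStats is plain data, so reporting is left to the caller. A small,
hypothetical helper (not part of this commit) shows one way the counters could
be summarized; it assumes the CacheStats class and the GIG constant defined
earlier in this file.

    def format_cache_stats(stats: CacheStats) -> str:
        # derive a hit rate from the raw counters and report sizes in GiB
        total = stats.hits + stats.misses
        hit_rate = stats.hits / total if total else 0.0
        return (
            f"hits={stats.hits} misses={stats.misses} hit_rate={hit_rate:.1%} "
            f"in_cache={stats.in_cache} high_watermark={stats.high_watermark / GIG:.2f}G"
        )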
@@ -158,21 +175,40 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         key: str,
         submodel_type: Optional[SubModelType] = None,
+        stats_name: Optional[str] = None,
     ) -> ModelLockerBase:
         """
         Retrieve model using key and optional submodel_type.
-        This may return an IndexError if the model is not in the cache.
+        :param key: Opaque model key
+        :param submodel_type: Type of the submodel to fetch
+        :param stats_name: A human-readable id for the model for the purposes of
+        stats reporting.
+        This may raise an IndexError if the model is not in the cache.
         """
         key = self._make_cache_key(key, submodel_type)
-        if key not in self._cached_models:
+        if key in self._cached_models:
+            self.stats.hits += 1
+        else:
+            self.stats.misses += 1
             raise IndexError(f"The model with key {key} is not in the cache.")
+        cache_entry = self._cached_models[key]
+        # more stats
+        stats_name = stats_name or key
+        self.stats.cache_size = int(self._max_cache_size * GIG)
+        self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+        self.stats.in_cache = len(self._cached_models)
+        self.stats.loaded_model_sizes[stats_name] = max(
+            self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+        )
         # this moves the entry to the top (right end) of the stack
         with suppress(Exception):
             self._cache_stack.remove(key)
         self._cache_stack.append(key)
-        cache_entry = self._cached_models[key]
         return ModelLocker(
             cache=self,
             cache_entry=cache_entry,
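The remove/append pair near the end of get() is standard most-recently-used
bookkeeping: the key is pulled out of wherever it sits in the stack and
re-appended, so the right end of the list always holds the most recently
requested models and eviction can work from the left end. A standalone sketch
of just that idea, with hypothetical keys:

    cache_stack = ["model-a", "model-b", "model-c"]
    key = "model-a"
    if key in cache_stack:
        cache_stack.remove(key)   # drop the key's old position, if any
    cache_stack.append(key)       # most recently used now sits at the right end
    assert cache_stack == ["model-b", "model-c", "model-a"]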