BREAKING CHANGES: invocations now require a model key, not base/type/name (see the sketch after the change list)

- Implement a new model loader and modify invocations and embeddings accordingly.

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fix up invocations that load and patch models.

- Move seamless and silencewarnings utils into a better location.
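To make the breaking change concrete for graph and API authors: a model is no longer addressed by its (base, type, name) triple but by the single key assigned when the model is registered. A minimal sketch; the field names below are illustrative, not the exact invocation schema:

```python
# Illustrative only -- these dicts mimic the shape of a model reference in an
# invocation payload; the real field names may differ.

# Before: a model was identified by base model, model type, and model name.
old_model_ref = {
    "base_model": "sd-1",
    "model_type": "main",
    "model_name": "stable-diffusion-v1-5",
}

# After: a model is identified solely by the key under which it is registered;
# base/type/name are resolved from the model record when needed.
new_model_ref = {
    "key": "0f7bbe4c-example-key",  # made-up key for illustration
}
```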
Lincoln Stein
2024-02-05 22:56:32 -05:00
committed by psychedelicious
parent 5745ce9c7d
commit 78ef946e01
31 changed files with 727 additions and 496 deletions

View File

@@ -1,4 +1,6 @@
"""Init file for ModelCache."""
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
from .model_cache_default import ModelCache # noqa F401
_all__ = ["ModelCacheBase", "ModelCache"]
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

View File

@@ -8,13 +8,13 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import Logger
from typing import Generic, Optional, TypeVar
from typing import Dict, Generic, Optional, TypeVar
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.config import AnyModel, SubModelType
class ModelLockerBase(ABC):
@@ -65,6 +65,19 @@ class CacheRecord(Generic[T]):
return self._locks > 0
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@@ -98,10 +111,22 @@ class ModelCacheBase(ABC, Generic[T]):
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None:
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> CacheStats:
"""Return collected CacheStats object."""
pass
@stats.setter
@abstractmethod
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
pass
@property
@abstractmethod
def logger(self) -> Logger:
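The new abstract stats property (with its setter) makes statistics collection opt-in: an external collector, such as an invocation statistics service, assigns a CacheStats instance before a session and reads the counters back afterwards. A hedged sketch against this interface (import path assumed as above):

```python
# Sketch only: works against the abstract ModelCacheBase interface; the
# import path is inferred from the package __init__ shown earlier.
from typing import Callable

from invokeai.backend.model_manager.load.model_cache import CacheStats, ModelCacheBase

def run_with_stats(cache: ModelCacheBase, workload: Callable[[], None]) -> CacheStats:
    """Attach a fresh collector, run the workload, and return the counters."""
    collector = CacheStats()
    cache.stats = collector   # opt in to collection for this run
    workload()                # anything that loads models through the cache
    return collector
```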

View File

@@ -24,19 +24,17 @@ import math
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
if choose_torch_device() == torch.device("mps"):
@@ -56,20 +54,6 @@ GIG = 1073741824
MB = 2**20
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
@@ -110,7 +94,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
# used for stats collection
self.stats = CacheStats()
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
@@ -140,6 +124,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Return the cap on cache size."""
return self._max_cache_size
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
self._stats = stats
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
@@ -189,21 +183,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
self.stats.hits += 1
if self.stats:
self.stats.hits += 1
else:
self.stats.misses += 1
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
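Because the lookup path above only touches the counters when a collector is attached, leaving stats unset keeps the hot path free of bookkeeping. Once a collector is present, its fields can be summarized along these lines (a sketch; GIG and the field names match the diff, the import path is assumed):

```python
from invokeai.backend.model_manager.load.model_cache import CacheStats

GIG = 1073741824  # same byte/GiB conversion constant the cache uses

def summarize(stats: CacheStats) -> str:
    """Render the collected counters as a one-line report."""
    lookups = stats.hits + stats.misses
    hit_rate = stats.hits / lookups if lookups else 0.0
    return (
        f"hit rate {hit_rate:.0%}, "
        f"high watermark {stats.high_watermark / GIG:.2f} of {stats.cache_size / GIG:.2f} GiB, "
        f"{stats.in_cache} models resident, {stats.cleared} cleared to make space"
    )
```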