model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein
2024-02-03 22:55:09 -05:00
committed by psychedelicious
parent 5c2884569e
commit 60aa3d4893
29 changed files with 2382 additions and 237 deletions

View File

@ -0,0 +1,35 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .load_base import AnyModelLoader, LoadedModel
from .model_cache.model_cache_default import ModelCache
from .convert_cache.convert_cache_default import ModelConvertCache
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent,'model_loaders').glob('*.py') if x.stem != '__init__']
for module in loaders:
print(f'module={module}')
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
app_config = app_config or InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=app_config)
return AnyModelLoader(app_config=app_config,
logger=logger,
ram_cache=ModelCache(logger=logger,
max_cache_size=app_config.ram_cache_size,
max_vram_cache_size=app_config.vram_cache_size
),
convert_cache=ModelConvertCache(app_config.models_convert_cache_path)
)
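
A minimal usage sketch of the standalone loader defined above, assuming `vae_config` is an AnyModelConfig record obtained elsewhere (for example from the model record store); the variable is a placeholder, not part of this commit:

from invokeai.backend.model_manager.load import get_standalone_loader

loader = get_standalone_loader(None)    # None falls back to InvokeAIAppConfig.get_config()
loaded = loader.load_model(vae_config)  # placeholder AnyModelConfig; returns a LoadedModel
vae = loaded.model                      # access the model without locking it into VRAM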

View File

@ -0,0 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ['ModelConvertCacheBase', 'ModelConvertCache']

View File

@ -0,0 +1,28 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
"""Return the maximum size of this cache directory."""
pass
@abstractmethod
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of the indicated size.
:param size: Size required (GB)
"""
pass
@abstractmethod
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass

View File

@ -0,0 +1,64 @@
"""
Placeholder for convert cache implementation.
"""
from pathlib import Path
import shutil
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util import GIG, directory_size
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float=10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
self._cache_path = cache_path
self._max_size = max_size
@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
return self._cache_path / key
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of the indicated size.
:param size: Size required (GB)
"""
size_needed = directory_size(self._cache_path) + size
max_size = int(self.max_size) * GIG
logger = InvokeAILogger.get_logger()
if size_needed <= max_size:
return
logger.debug(
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
)
# For this to work, we make the assumption that the directory contains
# a 'model_index.json', 'unet/config.json', or 'config.json' file at the top level.
# This should be true for any diffusers model.
def by_atime(path: Path) -> float:
for config in ["model_index.json", "unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime
return 0.0
# sort by last access time - the least-recently accessed models end up at the end of the list
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while size_needed > max_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
size_needed -= victim_size
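
A short sketch of how a loader might use this cache when converting a checkpoint; the key, size estimate, and conversion step are placeholders, not code from this commit:

from pathlib import Path
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCache

cache = ModelConvertCache(Path("/path/to/convert_cache"), max_size=10.0)
dest = cache.cache_path("my-model-key")   # placeholder key; one subdirectory per converted model
cache.make_room(2.5)                      # estimated size of the converted model, in GB (placeholder)
if not dest.exists():
    ...                                   # run the conversion and save_pretrained() into dest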

View File

@ -16,39 +16,11 @@ from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union
import torch
from diffusers import DiffusionPipeline
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.model_manager.ram_cache import ModelCacheBase
AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> None:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_cache.model_locker import ModelLockerBase
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
@dataclass
class LoadedModel:
@ -69,7 +41,7 @@ class LoadedModel:
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self.locker.model()
return self.locker.model
class ModelLoaderBase(ABC):
@ -89,9 +61,9 @@ class ModelLoaderBase(ABC):
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its key.
Return a model given its configuration.
Given a model key identified in the model configuration backend,
Given a model identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
:param model_config: Model configuration, as returned by ModelConfigRecordStore
@ -115,34 +87,32 @@ class AnyModelLoader:
# this tracks the loader subclasses
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@inject
def __init__(
self,
store: ModelRecordServiceBase,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase,
convert_cache: ModelConvertCacheBase,
):
"""Store the provided ModelRecordServiceBase and empty the registry."""
self._store = store
"""Initialize AnyModelLoader with its dependencies."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its key.
@property
def ram_cache(self) -> ModelCacheBase:
"""Return the RAM cache associated used by the loaders."""
return self._ram_cache
Given a model key identified in the model configuration backend,
return a ModelInfo object that can be used to retrieve the model.
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType]=None) -> LoadedModel:
"""
Return a model given its configuration.
:param key: model key, as known to the config backend
:param submodel_type: a SubModelType enum indicating the portion of
the model to retrieve (e.g. SubModelType.Vae)
"""
model_config = self._store.get_model(key)
implementation = self.__class__.get_implementation(
base=model_config.base, type=model_config.type, format=model_config.format
)
@ -165,7 +135,7 @@ class AnyModelLoader:
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
)
return implementation
@ -176,18 +146,10 @@ class AnyModelLoader:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
print("Registering class", subclass.__name__)
print("DEBUG: Registering class", subclass.__name__)
key = cls._to_registry_key(base, type, format)
cls._registry[key] = subclass
return subclass
return decorator
# in _init__.py will call something like
# def configure_loader_dependencies(binder):
# binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton)
# binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton)
# etc
# injector = Injector(configure_loader_dependencies)
# loader = injector.get(ModelFactory)
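
For reference, the registration pattern enabled by the decorator looks roughly like the sketch below; the class name is hypothetical, and the concrete registrations for VAEs appear in model_loaders/vae.py further down. BaseModelType.Any acts as the fallback base when get_implementation() finds no exact (base, type, format) match.

@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class MyVaeLoader(ModelLoader):  # hypothetical subclass; the real one lives in model_loaders/vae.py
    def _load_model(self, model_path, model_variant=None, submodel_type=None):
        ...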

View File

@ -8,15 +8,14 @@ from typing import Any, Dict, Optional, Tuple
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from injector import inject
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -35,7 +34,6 @@ class ConfigLoader(ConfigMixin):
class ModelLoader(ModelLoaderBase):
"""Default implementation of ModelLoaderBase."""
@inject # can inject instances of each of the classes in the call signature
def __init__(
self,
app_config: InvokeAIAppConfig,
@ -87,18 +85,15 @@ class ModelLoader(ModelLoaderBase):
def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> Path:
if not self._needs_conversion(config):
return model_path
cache_path: Path = self._convert_cache.cache_path(config.key)
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
cache_path: Path = self._convert_cache.cache_path(config.key)
if cache_path.exists():
return cache_path
return self._convert_model(config, model_path, cache_path)
self._convert_model(model_path, cache_path)
return cache_path
def _needs_conversion(self, config: AnyModelConfig) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
return False
def _load_if_needed(
@ -133,7 +128,7 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
)
def _convert_model(self, model_path: Path, cache_path: Path) -> None:
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
def _load_model(
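
The intent of the revised signatures is that a checkpoint-capable subclass decides for itself whether the cached conversion is stale, then writes the converted model to the supplied output path; a hedged sketch of the override contract (hypothetical class, simplified staleness test):

class MyCheckpointLoader(ModelLoader):  # hypothetical subclass
    def _needs_conversion(self, config, model_path, cache_path):
        # convert when there is no cached copy, or the checkpoint is newer than the cached conversion
        return not cache_path.exists() or model_path.stat().st_mtime > cache_path.stat().st_mtime

    def _convert_model(self, config, weights_path, output_path):
        ...  # write a diffusers folder to output_path
        return output_path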

View File

@ -0,0 +1,5 @@
"""Init file for RamCache."""
from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache
__all__ = ['ModelCacheBase', 'ModelCache']

View File

@ -10,34 +10,41 @@ model will be cleared and (re)loaded from disk when next needed.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Optional
from typing import Dict, Optional, TypeVar, Generic
import torch
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
from invokeai.backend.model_manager import AnyModel, SubModelType
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@abstractmethod
def lock(self) -> AnyModel:
"""Lock the contained model and move it into VRAM."""
pass
@abstractmethod
def unlock(self) -> None:
"""Unlock the contained model, and remove it from VRAM."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
"""Return the model."""
pass
T = TypeVar("T")
@dataclass
class CacheStats(object):
"""Data object to record statistics on cache hits/misses."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # high watermark of cache usage
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
@dataclass
class CacheRecord:
class CacheRecord(Generic[T]):
"""Elements of the cache."""
key: str
model: AnyModel
model: T
size: int
loaded: bool = False
_locks: int = 0
def lock(self) -> None:
@ -55,7 +62,7 @@ class CacheRecord:
return self._locks > 0
class ModelCacheBase(ABC):
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@property
@ -76,8 +83,14 @@ class ModelCacheBase(ABC):
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@property
@abstractmethod
def offload_unlocked_models(self) -> None:
def max_cache_size(self) -> float:
"""Return true if the cache is configured to lazily offload models in VRAM."""
pass
@abstractmethod
def offload_unlocked_models(self, size_required: int) -> None:
"""Offload from VRAM any models not actively in use."""
pass
@ -101,7 +114,7 @@ class ModelCacheBase(ABC):
def put(
self,
key: str,
model: AnyModel,
model: T,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
@ -134,11 +147,6 @@ class ModelCacheBase(ABC):
"""Get the total size of the models currently cached."""
pass
@abstractmethod
def get_stats(self) -> CacheStats:
"""Return cache hit/miss/size statistics."""
pass
@abstractmethod
def print_cuda_stats(self) -> None:
"""Log debugging information on CUDA usage."""

View File

@ -18,6 +18,7 @@ context. Use like this:
"""
import gc
import math
import time
from contextlib import suppress
@ -26,14 +27,14 @@ from typing import Any, Dict, List, Optional
import torch
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_locker import ModelLockerBase, ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -52,7 +53,7 @@ GIG = 1073741824
MB = 2**20
class ModelCache(ModelCacheBase):
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
def __init__(
@ -92,62 +93,9 @@ class ModelCache(ModelCacheBase):
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
self._cached_models: Dict[str, CacheRecord] = {}
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def lock(self) -> Any:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models()
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models()
self._cache.print_cuda_stats()
@property
def logger(self) -> Logger:
"""Return the logger used by the cache."""
@ -168,6 +116,11 @@ class ModelCache(ModelCacheBase):
"""Return the exection device (e.g. "cuda" for VRAM)."""
return self._execution_device
@property
def max_cache_size(self) -> float:
"""Return the cap on cache size."""
return self._max_cache_size
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
@ -207,18 +160,18 @@ class ModelCache(ModelCacheBase):
"""
Retrieve model using key and optional submodel_type.
This may return an UnknownModelException if the model is not in the cache.
This may raise an IndexError if the model is not in the cache.
"""
key = self._make_cache_key(key, submodel_type)
if key not in self._cached_models:
raise UnknownModelException
raise IndexError(f"The model with key {key} is not in the cache.")
# this moves the entry to the top (right end) of the stack
with suppress(Exception):
self._cache_stack.remove(key)
self._cache_stack.append(key)
cache_entry = self._cached_models[key]
return self.ModelLocker(
return ModelLocker(
cache=self,
cache_entry=cache_entry,
)
@ -234,19 +187,19 @@ class ModelCache(ModelCacheBase):
else:
return model_key
def offload_unlocked_models(self) -> None:
def offload_unlocked_models(self, size_required: int) -> None:
"""Move any unused models from VRAM."""
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
@ -305,28 +258,111 @@ class ModelCache(ModelCacheBase):
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
ram = "%4.2fG" % self.cache_size()
ram = "%4.2fG" % (self.cache_size() / GIG)
cached_models = 0
loaded_models = 0
locked_models = 0
in_ram_models = 0
in_vram_models = 0
locked_in_vram_models = 0
for cache_record in self._cached_models.values():
cached_models += 1
assert hasattr(cache_record.model, "device")
if cache_record.model.device is self.storage_device:
loaded_models += 1
if cache_record.model.device == self.storage_device:
in_ram_models += 1
else:
in_vram_models += 1
if cache_record.locked:
locked_models += 1
locked_in_vram_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
f" {cached_models}/{loaded_models}/{locked_models}"
f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
)
def get_stats(self) -> CacheStats:
"""Return cache hit/miss/size statistics."""
raise NotImplementedError
def make_room(self, size: int) -> None:
def make_room(self, model_size: int) -> None:
"""Make enough room in the cache to accommodate a new model of indicated size."""
raise NotImplementedError
# calculate how much memory this model will require
# multiplier = 2 if self.precision==torch.float32 else 1
bytes_needed = model_size
maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
current_size = self.cache_size()
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manually clear local variable references of just-finished function calls;
# for some reason Python doesn't want to collect them, even with an immediate gc.collect()
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if the referrers changed (due to the frame clear), else exit the loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
)
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
del self._cache_stack[pos]
del self._cached_models[model_key]
del cache_entry
else:
pos += 1
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

View File

@ -0,0 +1,59 @@
"""
Implementation of the ModelLocker class, which moves models in and out of VRAM.
"""
from abc import ABC, abstractmethod
from invokeai.backend.model_manager import AnyModel
from .model_cache_base import ModelLockerBase, ModelCacheBase, CacheRecord
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
def __init__(self, cache: ModelCacheBase[AnyModel], cache_entry: CacheRecord[AnyModel]):
"""
Initialize the model locker.
:param cache: The ModelCache object
:param cache_entry: The entry in the model cache
"""
self._cache = cache
self._cache_entry = cache_entry
@property
def model(self) -> AnyModel:
"""Return the model without moving it around."""
return self._cache_entry.model
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
return self.model
# NOTE that the model has to have the to() method in order for this code to move it into GPU!
self._cache_entry.lock()
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except Exception:
self._cache_entry.unlock()
raise
return self.model
def unlock(self) -> None:
"""Call upon exit from context."""
if not hasattr(self.model, "to"):
return
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()
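
Callers are expected to obtain a locker from the cache and bracket GPU work with lock()/unlock(); a minimal sketch (the cache object, key, and inference step are placeholders):

locker = cache.get("model-key")   # placeholder key; returns a ModelLocker for the cached entry
model = locker.lock()             # offloads unlocked models if needed, then moves this one to the GPU
try:
    ...                           # run inference with `model`
finally:
    locker.unlock()               # release the lock; the model may be offloaded lazily later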

View File

@ -0,0 +1,3 @@
"""
Init file for model_loaders.
"""

View File

@ -0,0 +1,83 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
import safetensors
from omegaconf import OmegaConf, DictConfig
from invokeai.backend.util.devices import torch_dtype
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeDiffusersModel(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise Exception("There are no submodels in VAEs")
vae_class = self._get_hf_load_class(model_path)
variant = model_variant.value if model_variant else None
result: AnyModel = vae_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant
) # type: ignore
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
print(f'DEBUG: last_modified={config.last_modified}')
print(f'DEBUG: cache_path={(dest_path / "config.json").stat().st_mtime}')
print(f'DEBUG: model_path={model_path.stat().st_mtime}')
if config.format != ModelFormat.Checkpoint:
return False
elif dest_path.exists() \
and (dest_path / "config.json").stat().st_mtime >= config.last_modified \
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime:
return False
else:
return True
def _convert_model(self,
config: AnyModelConfig,
weights_path: Path,
output_path: Path
) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
config_file = 'v1-inference.yaml' if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
dtype = torch_dtype()
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
assert isinstance(ckpt_config, DictConfig)
print('DEBUG: CONVERTING')
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
)
vae_model.to(dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True, torch_dtype=dtype)
return output_path
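
End to end, a checkpoint-format VAE config routed through AnyModelLoader is converted once into the convert cache and loaded as a diffusers module on subsequent calls; a hedged sketch (the config record is a placeholder):

loader = get_standalone_loader(None)
loaded = loader.load_model(vae_checkpoint_config)  # placeholder config with format=ModelFormat.Checkpoint
vae = loaded.model                                 # the converted diffusers VAE module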

View File

@ -48,6 +48,9 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder

View File

@ -1,31 +0,0 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class VaeDiffusersModel(ModelLoader):
"""Class to load VAE models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> Dict[str, torch.Tensor]:
if submodel_type is not None:
raise Exception("There are no submodels in VAEs")
vae_class = self._get_hf_load_class(model_path)
variant = model_variant.value if model_variant else ""
result: Dict[str, torch.Tensor] = vae_class.from_pretrained(
model_path, torch_dtype=self._torch_dtype, variant=variant
) # type: ignore
return result