Mirror of https://github.com/invoke-ai/InvokeAI

Commit 5c2884569e (parent a1307b9f2e): add ram cache module and support files
@@ -152,6 +152,7 @@ class _DiffusersConfig(ModelConfigBase):
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
     repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
+


 class LoRAConfig(ModelConfigBase):
     """Model config for LoRA/Lycoris models."""

@@ -179,6 +180,7 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
     type: Literal[ModelType.ControlNet] = ModelType.ControlNet
     format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
+


 class ControlNetCheckpointConfig(_CheckpointConfig):
     """Model config for ControlNet models (diffusers version)."""

@@ -214,6 +216,7 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
     prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
     upcast_attention: bool = False
+


 class ONNXSD1Config(_MainConfig):
     """Model config for ONNX format models based on sd-1."""

invokeai/backend/model_manager/load/__init__.py (new, empty file)

invokeai/backend/model_manager/load/load_base.py (new file)
@@ -0,0 +1,193 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""
Base class for model loading in InvokeAI.

Use like this:

  loader = AnyModelLoader(...)
  loaded_model = loader.get_model('019ab39adfa1840455')
  with loaded_model as model:  # context manager moves model into VRAM
      # do something with loaded_model
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union

import torch
from diffusers import DiffusionPipeline
from injector import inject

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.model_manager.ram_cache import ModelCacheBase

AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]


class ModelLockerBase(ABC):
    """Base class for the model locker used by the loader."""

    @abstractmethod
    def lock(self) -> None:
        """Lock the contained model and move it into VRAM."""
        pass

    @abstractmethod
    def unlock(self) -> None:
        """Unlock the contained model, and remove it from VRAM."""
        pass

    @property
    @abstractmethod
    def model(self) -> AnyModel:
        """Return the model."""
        pass


@dataclass
class LoadedModel:
    """Context manager object that mediates transfer from RAM<->VRAM."""

    config: AnyModelConfig
    locker: ModelLockerBase

    def __enter__(self) -> AnyModel:  # I think load_file() always returns a dict
        """Context entry."""
        self.locker.lock()
        return self.model

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Context exit."""
        self.locker.unlock()

    @property
    def model(self) -> AnyModel:
        """Return the model without locking it."""
        return self.locker.model()


class ModelLoaderBase(ABC):
    """Abstract base class for loading models into RAM/VRAM."""

    @abstractmethod
    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase,
        convert_cache: ModelConvertCacheBase,
    ):
        """Initialize the loader."""
        pass

    @abstractmethod
    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Return a model given its key.

        Given a model key identified in the model configuration backend,
        return a ModelInfo object that can be used to retrieve the model.

        :param model_config: Model configuration, as returned by ModelConfigRecordStore
        :param submodel_type: a ModelType enum indicating the portion of
               the model to retrieve (e.g. ModelType.Vae)
        """
        pass

    @abstractmethod
    def get_size_fs(
        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
    ) -> int:
        """Return size in bytes of the model, calculated before loading."""
        pass


# TO DO: Better name?
class AnyModelLoader:
    """This class manages the model loaders and invokes the correct one to load a model of given base and type."""

    # this tracks the loader subclasses
    _registry: Dict[str, Type[ModelLoaderBase]] = {}

    @inject
    def __init__(
        self,
        store: ModelRecordServiceBase,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase,
        convert_cache: ModelConvertCacheBase,
    ):
        """Store the provided ModelRecordServiceBase and empty the registry."""
        self._store = store
        self._app_config = app_config
        self._logger = logger
        self._ram_cache = ram_cache
        self._convert_cache = convert_cache

    def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Return a model given its key.

        Given a model key identified in the model configuration backend,
        return a ModelInfo object that can be used to retrieve the model.

        :param key: model key, as known to the config backend
        :param submodel_type: a ModelType enum indicating the portion of
               the model to retrieve (e.g. ModelType.Vae)
        """
        model_config = self._store.get_model(key)
        implementation = self.__class__.get_implementation(
            base=model_config.base, type=model_config.type, format=model_config.format
        )
        return implementation(
            app_config=self._app_config,
            logger=self._logger,
            ram_cache=self._ram_cache,
            convert_cache=self._convert_cache,
        ).load_model(model_config, submodel_type)

    @staticmethod
    def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
        return "-".join([base.value, type.value, format.value])

    @classmethod
    def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]:
        """Get subclass of ModelLoaderBase registered to handle base and type."""
        key1 = cls._to_registry_key(base, type, format)  # for a specific base type
        key2 = cls._to_registry_key(BaseModelType.Any, type, format)  # with wildcard Any
        implementation = cls._registry.get(key1) or cls._registry.get(key2)
        if not implementation:
            raise NotImplementedError(
                f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
            )
        return implementation

    @classmethod
    def register(
        cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
        """Define a decorator which registers the subclass of loader."""

        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
            print("Registering class", subclass.__name__)
            key = cls._to_registry_key(base, type, format)
            cls._registry[key] = subclass
            return subclass

        return decorator


# in __init__.py will call something like
# def configure_loader_dependencies(binder):
#     binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton)
#     binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton)
#     etc
# injector = Injector(configure_loader_dependencies)
# loader = injector.get(ModelFactory)
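A minimal usage sketch of the loader API above (not part of the commit): the collaborating services passed to the constructor are assumed to have been wired up elsewhere, for example by the injector configuration hinted at in the trailing comment, and run_inference() is a placeholder.

# Hedged usage sketch; model_record_store, app_config, logger, ram_cache, convert_cache
# and run_inference are placeholders assumed to exist in the surrounding application.
from invokeai.backend.model_manager.load.load_base import AnyModelLoader

loader = AnyModelLoader(
    store=model_record_store,      # ModelRecordServiceBase
    app_config=app_config,         # InvokeAIAppConfig
    logger=logger,                 # logging.Logger
    ram_cache=ram_cache,           # ModelCacheBase
    convert_cache=convert_cache,   # ModelConvertCacheBase
)

# Looks up the config record by key, then dispatches to the ModelLoaderBase subclass
# registered for that (base, type, format) combination.
loaded_model = loader.get_model("019ab39adfa1840455")

with loaded_model as model:   # __enter__ locks the model into VRAM for the duration
    run_inference(model)      # placeholder for whatever consumes the model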
invokeai/backend/model_manager/load/load_default.py (new file)
@@ -0,0 +1,168 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Default implementation of model loading in InvokeAI."""

import sys
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from injector import inject

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.devices import choose_torch_device, torch_dtype


class ConfigLoader(ConfigMixin):
    """Subclass of ConfigMixin for loading diffusers configuration files."""

    @classmethod
    def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """Load a diffusers ConfigMixin configuration."""
        cls.config_name = kwargs.pop("config_name")
        # Diffusers doesn't provide typing info
        return super().load_config(*args, **kwargs)  # type: ignore


# TO DO: The loader is not thread safe!
class ModelLoader(ModelLoaderBase):
    """Default implementation of ModelLoaderBase."""

    @inject  # can inject instances of each of the classes in the call signature
    def __init__(
        self,
        app_config: InvokeAIAppConfig,
        logger: Logger,
        ram_cache: ModelCacheBase,
        convert_cache: ModelConvertCacheBase,
    ):
        """Initialize the loader."""
        self._app_config = app_config
        self._logger = logger
        self._ram_cache = ram_cache
        self._convert_cache = convert_cache
        self._torch_dtype = torch_dtype(choose_torch_device())
        self._size: Optional[int] = None  # model size

    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Return a model given its configuration.

        Given a model's configuration as returned by the ModelRecordConfigStore service,
        return a LoadedModel object that can be used for inference.

        :param model_config: Configuration record for this model
        :param submodel_type: a ModelType enum indicating the portion of
               the model to retrieve (e.g. ModelType.Vae)
        """
        if model_config.type == "main" and not submodel_type:
            raise InvalidModelConfigException("submodel_type is required when loading a main model")

        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
        if is_submodel_override:
            submodel_type = None

        if not model_path.exists():
            raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")

        model_path = self._convert_if_needed(model_config, model_path, submodel_type)
        locker = self._load_if_needed(model_config, model_path, submodel_type)
        return LoadedModel(config=model_config, locker=locker)

    # IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
    # and submodels!!!!
    def _get_model_path(
        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
    ) -> Tuple[Path, bool]:
        model_base = self._app_config.models_path
        return ((model_base / config.path).resolve(), False)

    def _convert_if_needed(
        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
    ) -> Path:
        if not self._needs_conversion(config):
            return model_path

        self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type))
        cache_path: Path = self._convert_cache.cache_path(config.key)
        if cache_path.exists():
            return cache_path

        self._convert_model(model_path, cache_path)
        return cache_path

    def _needs_conversion(self, config: AnyModelConfig) -> bool:
        return False

    def _load_if_needed(
        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
    ) -> ModelLockerBase:
        # TO DO: This is not thread safe!
        if self._ram_cache.exists(config.key, submodel_type):
            return self._ram_cache.get(config.key, submodel_type)

        model_variant = getattr(config, "repo_variant", None)
        self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))

        # This is where the model is actually loaded!
        with skip_torch_weight_init():
            loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type)

        self._ram_cache.put(
            config.key,
            submodel_type=submodel_type,
            model=loaded_model,
        )

        return self._ram_cache.get(config.key, submodel_type)

    def get_size_fs(
        self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
    ) -> int:
        """Get the size of the model on disk."""
        return calc_model_size_by_fs(
            model_path=model_path,
            subfolder=submodel_type.value if submodel_type else None,
            variant=config.repo_variant if hasattr(config, "repo_variant") else None,
        )

    def _convert_model(self, model_path: Path, cache_path: Path) -> None:
        raise NotImplementedError

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        raise NotImplementedError

    def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
        return ConfigLoader.load_config(model_path, config_name=config_name)

    # TO DO: Add exception handling
    def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
        if module in ["diffusers", "transformers"]:
            res_type = sys.modules[module]
        else:
            res_type = sys.modules["diffusers"].pipelines
        result: ModelMixin = getattr(res_type, class_name)
        return result

    # TO DO: Add exception handling
    def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
        if submodel_type:
            config = self._load_diffusers_config(model_path, config_name="model_index.json")
            module, class_name = config[submodel_type.value]
            return self._hf_definition_to_type(module=module, class_name=class_name)
        else:
            config = self._load_diffusers_config(model_path, config_name="config.json")
            class_name = config["_class_name"]
            return self._hf_definition_to_type(module="diffusers", class_name=class_name)
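As an aside, the class-resolution step performed by _hf_definition_to_type()/_get_hf_load_class() can be pictured with this small, self-contained sketch. It uses importlib rather than the sys.modules lookup the loader itself uses, and the JSON shape in the comment is the standard diffusers model_index.json layout; the path in the usage comment is a placeholder.

# Standalone sketch (not InvokeAI code) of resolving a submodel class from model_index.json,
# which maps submodel names to [module, class] pairs, e.g. "unet": ["diffusers", "UNet2DConditionModel"].
import importlib
import json
from pathlib import Path


def resolve_submodel_class(model_path: Path, submodel: str) -> type:
    index = json.loads((model_path / "model_index.json").read_text())
    module_name, class_name = index[submodel]
    return getattr(importlib.import_module(module_name), class_name)

# usage (placeholder path):
# unet_class = resolve_submodel_class(Path("/path/to/sd-model"), "unet")
# unet = unet_class.from_pretrained("/path/to/sd-model/unet", variant="fp16")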
invokeai/backend/model_manager/load/memory_snapshot.py (new file)
@@ -0,0 +1,100 @@
import gc
from typing import Optional

import psutil
import torch
from typing_extensions import Self

from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2

GB = 2**30  # 1 GB


class MemorySnapshot:
    """A snapshot of RAM and VRAM usage. All values are in bytes."""

    def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
        """Initialize a MemorySnapshot.

        Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.

        Args:
            process_ram (int): CPU RAM used by the current process.
            vram (Optional[int]): VRAM used by torch.
            malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil.
        """
        self.process_ram = process_ram
        self.vram = vram
        self.malloc_info = malloc_info

    @classmethod
    def capture(cls, run_garbage_collector: bool = True) -> Self:
        """Capture and return a MemorySnapshot.

        Note: This function has significant overhead, particularly if `run_garbage_collector == True`.

        Args:
            run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM
                usage. Defaults to True.

        Returns:
            MemorySnapshot
        """
        if run_garbage_collector:
            gc.collect()

        # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is
        # supported on all platforms.
        process_ram = psutil.Process().memory_info().rss

        if torch.cuda.is_available():
            vram = torch.cuda.memory_allocated()
        else:
            # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have
            # time to test it properly.
            vram = None

        try:
            malloc_info = LibcUtil().mallinfo2()  # type: ignore
        except (OSError, AttributeError):
            # OSError: This is expected in environments that do not have the 'libc.so.6' shared library.
            # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33)
            # TODO: Does `mallinfo` work?
            malloc_info = None

        return cls(process_ram, vram, malloc_info)


def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
    """Get a pretty string describing the difference between two `MemorySnapshot`s."""

    def get_msg_line(prefix: str, val1: int, val2: int) -> str:
        diff = val2 - val1
        return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"

    msg = ""

    if snapshot_1 is None or snapshot_2 is None:
        return msg

    msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)

    if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
        msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd)

        msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks)

        msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks)

        libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd
        libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd
        msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2)

        libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd
        libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
        msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)

    if snapshot_1.vram is not None and snapshot_2.vram is not None:
        msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)

    return msg
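A short usage sketch (not from the commit): take a snapshot on either side of an allocation-heavy step and log the formatted delta.

# Hedged sketch: measure the memory cost of an operation with two snapshots.
import torch

from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff

before = MemorySnapshot.capture()
buffers = [torch.zeros(1024, 1024) for _ in range(64)]  # stand-in for loading a model
after = MemorySnapshot.capture()
print(get_pretty_snapshot_diff(before, after))  # prints a "Process RAM ..." line, plus VRAM/libc lines when available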
invokeai/backend/model_manager/load/model_util.py (new file)
@@ -0,0 +1,109 @@
# Copyright (c) 2024 The InvokeAI Development Team
"""Various utility functions needed by the loader and caching system."""

import json
from pathlib import Path
from typing import Optional, Union

import torch
from diffusers import DiffusionPipeline

from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel


def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int:
    """Get size of a model in memory in bytes."""
    if isinstance(model, DiffusionPipeline):
        return _calc_pipeline_by_data(model)
    elif isinstance(model, torch.nn.Module):
        return _calc_model_by_data(model)
    elif isinstance(model, IAIOnnxRuntimeModel):
        return _calc_onnx_model_by_data(model)
    else:
        return 0


def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
    res = 0
    assert hasattr(pipeline, "components")
    for submodel_key in pipeline.components.keys():
        submodel = getattr(pipeline, submodel_key)
        if submodel is not None and isinstance(submodel, torch.nn.Module):
            res += _calc_model_by_data(submodel)
    return res


def _calc_model_by_data(model: torch.nn.Module) -> int:
    mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
    mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
    mem: int = mem_params + mem_bufs  # in bytes
    return mem


def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
    tensor_size = model.tensors.size() * 2  # The session doubles this
    mem = tensor_size  # in bytes
    return mem


def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
    """Estimate the size of a model on disk in bytes."""
    if subfolder is not None:
        model_path = model_path / subfolder

    # this can happen when, for example, the safety checker is not downloaded.
    if not model_path.exists():
        return 0

    all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]

    fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
    bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
    other_files = set(all_files) - fp16_files - bit8_files

    if variant is None:
        files = other_files
    elif variant == "fp16":
        files = fp16_files
    elif variant == "8bit":
        files = bit8_files
    else:
        raise NotImplementedError(f"Unknown variant: {variant}")

    # try read from index if exists
    index_postfix = ".index.json"
    if variant is not None:
        index_postfix = f".index.{variant}.json"

    for file in files:
        if not file.name.endswith(index_postfix):
            continue
        try:
            with open(model_path / file, "r") as f:
                index_data = json.loads(f.read())
            return int(index_data["metadata"]["total_size"])
        except Exception:
            pass

    # calculate files size if there is no index file
    formats = [
        (".safetensors",),  # safetensors
        (".bin",),  # torch
        (".onnx", ".pb"),  # onnx
        (".msgpack",),  # flax
        (".ckpt",),  # tf
        (".h5",),  # tf2
    ]

    for file_format in formats:
        model_files = [f for f in files if f.suffix in file_format]
        if len(model_files) == 0:
            continue

        model_size = 0
        for model_file in model_files:
            file_stats = (model_path / model_file).stat()
            model_size += file_stats.st_size
        return model_size

    return 0  # scheduler/feature_extractor/tokenizer - models without loading to gpu
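A brief usage sketch (the path is a placeholder): estimate the RAM needed for a submodel before loading it, then refine the number from the live object afterwards.

# Hedged sketch of the two sizing paths used by the loader and cache.
from pathlib import Path

from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs

model_dir = Path("/path/to/diffusers-model")  # placeholder
estimate = calc_model_size_by_fs(model_dir, subfolder="unet", variant="fp16")
print(f"disk-based estimate: {estimate / 2**30:.2f} GB")

# Once the submodel is loaded, the in-memory figure is what the RAM cache records:
# actual = calc_model_size_by_data(loaded_unet)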
invokeai/backend/model_manager/load/optimizations.py (new file)
@@ -0,0 +1,30 @@
from contextlib import contextmanager

import torch


def _no_op(*args, **kwargs):
    pass


@contextmanager
def skip_torch_weight_init():
    """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
    to skip weight initialization.

    By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
    distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
    completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
    monkey-patches common torch layers to skip the weight initialization step.
    """
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved_functions = [m.reset_parameters for m in torch_modules]

    try:
        for torch_module in torch_modules:
            torch_module.reset_parameters = _no_op

        yield None
    finally:
        for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True):
            torch_module.reset_parameters = saved_function
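Usage sketch (assumption: the layer's weights are overwritten from a checkpoint afterwards, so skipping initialization is safe):

# Construct layers cheaply inside the context, then load real weights over them.
import torch

from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init

with skip_torch_weight_init():
    layer = torch.nn.Linear(4096, 4096)  # reset_parameters() is a no-op here, so weights stay uninitialized

# stand-in for a real checkpoint; in practice this comes from load_file()/torch.load()
state_dict = {"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)}
layer.load_state_dict(state_dict)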
invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py (new file)
@@ -0,0 +1,145 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, Optional

import torch

from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase


@dataclass
class CacheStats(object):
    """Data object to record statistics on cache hits/misses."""

    hits: int = 0  # cache hits
    misses: int = 0  # cache misses
    high_watermark: int = 0  # amount of cache used
    in_cache: int = 0  # number of models in cache
    cleared: int = 0  # number of models cleared to make space
    cache_size: int = 0  # total size of cache
    loaded_model_sizes: Dict[str, int] = field(default_factory=dict)


@dataclass
class CacheRecord:
    """Elements of the cache."""

    key: str
    model: AnyModel
    size: int
    _locks: int = 0

    def lock(self) -> None:
        """Lock this record."""
        self._locks += 1

    def unlock(self) -> None:
        """Unlock this record."""
        self._locks -= 1
        assert self._locks >= 0

    @property
    def locked(self) -> bool:
        """Return true if record is locked."""
        return self._locks > 0


class ModelCacheBase(ABC):
    """Virtual base class for RAM model cache."""

    @property
    @abstractmethod
    def storage_device(self) -> torch.device:
        """Return the storage device (e.g. "CPU" for RAM)."""
        pass

    @property
    @abstractmethod
    def execution_device(self) -> torch.device:
        """Return the execution device (e.g. "cuda" for VRAM)."""
        pass

    @property
    @abstractmethod
    def lazy_offloading(self) -> bool:
        """Return true if the cache is configured to lazily offload models in VRAM."""
        pass

    @abstractmethod
    def offload_unlocked_models(self) -> None:
        """Offload from VRAM any models not actively in use."""
        pass

    @abstractmethod
    def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None:
        """Move model into the indicated device."""
        pass

    @property
    @abstractmethod
    def logger(self) -> Logger:
        """Return the logger used by the cache."""
        pass

    @abstractmethod
    def make_room(self, size: int) -> None:
        """Make enough room in the cache to accommodate a new model of indicated size."""
        pass

    @abstractmethod
    def put(
        self,
        key: str,
        model: AnyModel,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Store model under key and optional submodel_type."""
        pass

    @abstractmethod
    def get(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> ModelLockerBase:
        """
        Retrieve model locker object using key and optional submodel_type.

        This may raise an UnknownModelException if the model is not in the cache.
        """
        pass

    @abstractmethod
    def exists(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> bool:
        """Return true if the model identified by key and submodel_type is in the cache."""
        pass

    @abstractmethod
    def cache_size(self) -> int:
        """Get the total size of the models currently cached."""
        pass

    @abstractmethod
    def get_stats(self) -> CacheStats:
        """Return cache hit/miss/size statistics."""
        pass

    @abstractmethod
    def print_cuda_stats(self) -> None:
        """Log debugging information on CUDA usage."""
        pass
@@ -0,0 +1,332 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
# TODO: Add Stalker's proper name to copyright
"""
Manage a RAM cache of diffusion/transformer models for fast switching.
They are moved between GPU VRAM and CPU RAM as necessary. If the cache
grows larger than a preset maximum, then the least recently used
model will be cleared and (re)loaded from disk when next needed.

The cache returns context manager generators designed to load the
model into the GPU within the context, and unload outside the
context. Use like this:

   cache = ModelCache(max_cache_size=7.5)
   with cache.get_model('runwayml/stable-diffusion-1-5') as SD1,
        cache.get_model('stabilityai/stable-diffusion-2') as SD2:
       do_something_in_GPU(SD1,SD2)
"""

import math
import time
from contextlib import suppress
from logging import Logger
from typing import Any, Dict, List, Optional

import torch

from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger

if choose_torch_device() == torch.device("mps"):
    from torch import mps

# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0

# amount of GPU memory to hold in reserve for use by generations (GB)
DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75

# actual size of a gig
GIG = 1073741824

# Size of a MB in bytes.
MB = 2**20


class ModelCache(ModelCacheBase):
    """Implementation of ModelCacheBase."""

    def __init__(
        self,
        max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
        execution_device: torch.device = torch.device("cuda"),
        storage_device: torch.device = torch.device("cpu"),
        precision: torch.dtype = torch.float16,
        sequential_offload: bool = False,
        lazy_offloading: bool = True,
        sha_chunksize: int = 16777216,
        log_memory_usage: bool = False,
        logger: Optional[Logger] = None,
    ):
        """
        Initialize the model RAM cache.

        :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
        :param execution_device: Torch device to load active model into [torch.device('cuda')]
        :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
        :param precision: Precision for loaded models [torch.float16]
        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
            behaviour.
        """
        # allow lazy offloading only when vram cache enabled
        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
        self._precision: torch.dtype = precision
        self._max_cache_size: float = max_cache_size
        self._max_vram_cache_size: float = max_vram_cache_size
        self._execution_device: torch.device = execution_device
        self._storage_device: torch.device = storage_device
        self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
        self._log_memory_usage = log_memory_usage

        # used for stats collection
        self.stats = None

        self._cached_models: Dict[str, CacheRecord] = {}
        self._cache_stack: List[str] = []

    class ModelLocker(ModelLockerBase):
        """Internal class that mediates movement in and out of GPU."""

        def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord):
            """
            Initialize the model locker.

            :param cache: The ModelCache object
            :param cache_entry: The entry in the model cache
            """
            self._cache = cache
            self._cache_entry = cache_entry

        @property
        def model(self) -> AnyModel:
            """Return the model without moving it around."""
            return self._cache_entry.model

        def lock(self) -> Any:
            """Move the model into the execution device (GPU) and lock it."""
            if not hasattr(self.model, "to"):
                return self.model

            # NOTE that the model has to have the to() method in order for this code to move it into GPU!
            self._cache_entry.lock()

            try:
                if self._cache.lazy_offloading:
                    self._cache.offload_unlocked_models()

                self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)

                self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
                self._cache.print_cuda_stats()

            except Exception:
                self._cache_entry.unlock()
                raise
            return self.model

        def unlock(self) -> None:
            """Call upon exit from context."""
            if not hasattr(self.model, "to"):
                return

            self._cache_entry.unlock()
            if not self._cache.lazy_offloading:
                self._cache.offload_unlocked_models()
                self._cache.print_cuda_stats()

    @property
    def logger(self) -> Logger:
        """Return the logger used by the cache."""
        return self._logger

    @property
    def lazy_offloading(self) -> bool:
        """Return true if the cache is configured to lazily offload models in VRAM."""
        return self._lazy_offloading

    @property
    def storage_device(self) -> torch.device:
        """Return the storage device (e.g. "CPU" for RAM)."""
        return self._storage_device

    @property
    def execution_device(self) -> torch.device:
        """Return the execution device (e.g. "cuda" for VRAM)."""
        return self._execution_device

    def cache_size(self) -> int:
        """Get the total size of the models currently cached."""
        total = 0
        for cache_record in self._cached_models.values():
            total += cache_record.size
        return total

    def exists(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> bool:
        """Return true if the model identified by key and submodel_type is in the cache."""
        key = self._make_cache_key(key, submodel_type)
        return key in self._cached_models

    def put(
        self,
        key: str,
        model: AnyModel,
        submodel_type: Optional[SubModelType] = None,
    ) -> None:
        """Store model under key and optional submodel_type."""
        key = self._make_cache_key(key, submodel_type)
        assert key not in self._cached_models

        loaded_model_size = calc_model_size_by_data(model)
        cache_record = CacheRecord(key, model, loaded_model_size)
        self._cached_models[key] = cache_record
        self._cache_stack.append(key)

    def get(
        self,
        key: str,
        submodel_type: Optional[SubModelType] = None,
    ) -> ModelLockerBase:
        """
        Retrieve model using key and optional submodel_type.

        This may raise an UnknownModelException if the model is not in the cache.
        """
        key = self._make_cache_key(key, submodel_type)
        if key not in self._cached_models:
            raise UnknownModelException

        # this moves the entry to the top (right end) of the stack
        with suppress(Exception):
            self._cache_stack.remove(key)
        self._cache_stack.append(key)
        cache_entry = self._cached_models[key]
        return self.ModelLocker(
            cache=self,
            cache_entry=cache_entry,
        )

    def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
        if self._log_memory_usage:
            return MemorySnapshot.capture()
        return None

    def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str:
        if submodel_type:
            return f"{model_key}:{submodel_type.value}"
        else:
            return model_key

    def offload_unlocked_models(self) -> None:
        """Move any unused models from VRAM."""
        reserved = self._max_vram_cache_size * GIG
        vram_in_use = torch.cuda.memory_allocated()
        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
            if vram_in_use <= reserved:
                break
            if not cache_entry.locked:
                self.move_model_to_device(cache_entry, self.storage_device)

                vram_in_use = torch.cuda.memory_allocated()
                self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")

        torch.cuda.empty_cache()
        if choose_torch_device() == torch.device("mps"):
            mps.empty_cache()

    # TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size
    # for printing debugging messages. Revisit whether this is necessary
    def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None:
        """Move model into the indicated device."""
        # These attributes are not in the base class but in derived classes
        assert hasattr(cache_entry.model, "device")
        assert hasattr(cache_entry.model, "to")

        source_device = cache_entry.model.device

        # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
        # multi-GPU.
        if torch.device(source_device).type == torch.device(target_device).type:
            return

        start_model_to_time = time.time()
        snapshot_before = self._capture_memory_snapshot()
        cache_entry.model.to(target_device)
        snapshot_after = self._capture_memory_snapshot()
        end_model_to_time = time.time()
        self.logger.debug(
            f"Moved model '{cache_entry.key}' from {source_device} to"
            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
        )

        if (
            snapshot_before is not None
            and snapshot_after is not None
            and snapshot_before.vram is not None
            and snapshot_after.vram is not None
        ):
            vram_change = abs(snapshot_before.vram - snapshot_after.vram)

            # If the estimated model size does not match the change in VRAM, log a warning.
            if not math.isclose(
                vram_change,
                cache_entry.size,
                rel_tol=0.1,
                abs_tol=10 * MB,
            ):
                self.logger.debug(
                    f"Moving model '{cache_entry.key}' from {source_device} to"
                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
                    " estimated size may be incorrect. Estimated model size:"
                    f" {(cache_entry.size/GIG):.3f} GB.\n"
                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
                )

    def print_cuda_stats(self) -> None:
        """Log CUDA diagnostics."""
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
        ram = "%4.2fG" % (self.cache_size() / GIG)

        cached_models = 0
        loaded_models = 0
        locked_models = 0
        for cache_record in self._cached_models.values():
            cached_models += 1
            assert hasattr(cache_record.model, "device")
            if cache_record.model.device is self.storage_device:
                loaded_models += 1
            if cache_record.locked:
                locked_models += 1

        self.logger.debug(
            f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
            f" {cached_models}/{loaded_models}/{locked_models}"
        )

    def get_stats(self) -> CacheStats:
        """Return cache hit/miss/size statistics."""
        raise NotImplementedError

    def make_room(self, size: int) -> None:
        """Make enough room in the cache to accommodate a new model of indicated size."""
        raise NotImplementedError
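A hedged sketch of the put/get/lock cycle against this cache. The key and the tiny stand-in model are invented for illustration, a CUDA device is assumed, and ModelCache is assumed to have been imported from the implementation module above.

# Real cache entries are diffusers/transformers models, which already expose the .device
# attribute that move_model_to_device() asserts on; a bare nn.Module does not, hence the
# tiny stand-in below.
import torch


class _TinyModel(torch.nn.Linear):
    @property
    def device(self) -> torch.device:
        return self.weight.device


cache = ModelCache(max_cache_size=6.0, max_vram_cache_size=2.75)  # ModelCache from the module above
cache.put("3f871d1e-0c9d-4af8-ad1e", model=_TinyModel(8, 8))      # key format is up to the caller

locker = cache.get("3f871d1e-0c9d-4af8-ad1e")
model = locker.lock()   # offloads unlocked models if needed, then moves this one to the execution device
try:
    pass                # run inference with `model` here
finally:
    locker.unlock()     # allows offload_unlocked_models() to reclaim the VRAM later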
invokeai/backend/model_manager/load/vae.py (new file)
@@ -0,0 +1,31 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""

from pathlib import Path
from typing import Dict, Optional

import torch

from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader


@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
class VaeDiffusersModel(ModelLoader):
    """Class to load VAE models."""

    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> Dict[str, torch.Tensor]:
        if submodel_type is not None:
            raise Exception("There are no submodels in VAEs")
        vae_class = self._get_hf_load_class(model_path)
        variant = model_variant.value if model_variant else ""
        result: Dict[str, torch.Tensor] = vae_class.from_pretrained(
            model_path, torch_dtype=self._torch_dtype, variant=variant
        )  # type: ignore
        return result
invokeai/backend/model_manager/onnx_runtime.py (new file)
@@ -0,0 +1,216 @@
# Copyright (c) 2024 The InvokeAI Development Team
import os
import sys
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import onnx
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers

ONNX_WEIGHTS_NAME = "model.onnx"


# NOTE FROM LS: This was copied from Stalker's original implementation.
# I have not yet gone through and fixed all the type hints
class IAIOnnxRuntimeModel:
    class _tensor_access:
        def __init__(self, model):  # type: ignore
            self.model = model
            self.indexes = {}
            for idx, obj in enumerate(self.model.proto.graph.initializer):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            value = self.model.proto.graph.initializer[self.indexes[key]]
            return numpy_helper.to_array(value)

        def __setitem__(self, key: str, value: np.ndarray):  # type: ignore
            new_node = numpy_helper.from_array(value)
            # set_external_data(new_node, location="in-memory-location")
            new_node.name = key
            # new_node.ClearField("raw_data")
            del self.model.proto.graph.initializer[self.indexes[key]]
            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
            # self.model.data[key] = OrtValue.ortvalue_from_numpy(value)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            return self.indexes[key] in self.model.proto.graph.initializer

        def items(self) -> List[Tuple[str, Any]]:  # fixme
            raise NotImplementedError("tensor.items")
            # return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            return list(self.indexes.keys())

        def values(self) -> List[Any]:  # fixme
            raise NotImplementedError("tensor.values")
            # return [obj for obj in self.raw_proto]

        def size(self) -> int:
            bytesSum = 0
            for node in self.model.proto.graph.initializer:
                bytesSum += sys.getsizeof(node.raw_data)
            return bytesSum

    class _access_helper:
        def __init__(self, raw_proto):  # type: ignore
            self.indexes = {}
            self.raw_proto = raw_proto
            for idx, obj in enumerate(raw_proto):
                self.indexes[obj.name] = idx

        def __getitem__(self, key: str):  # type: ignore
            return self.raw_proto[self.indexes[key]]

        def __setitem__(self, key: str, value):  # type: ignore
            index = self.indexes[key]
            del self.raw_proto[index]
            self.raw_proto.insert(index, value)

        # __delitem__

        def __contains__(self, key: str) -> bool:
            return key in self.indexes

        def items(self) -> List[Tuple[str, Any]]:
            return [(obj.name, obj) for obj in self.raw_proto]

        def keys(self) -> List[str]:
            return list(self.indexes.keys())

        def values(self) -> List[Any]:  # fixme
            return list(self.raw_proto)

    def __init__(self, model_path: str, provider: Optional[str]):
        self.path = model_path
        self.session = None
        self.provider = provider
        """
        self.data_path = self.path + "_data"
        if not os.path.exists(self.data_path):
            print(f"Moving model tensors to separate file: {self.data_path}")
            tmp_proto = onnx.load(model_path, load_external_data=True)
            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
            del tmp_proto
            gc.collect()

        self.proto = onnx.load(model_path, load_external_data=False)
        """

        self.proto = onnx.load(model_path, load_external_data=True)
        # self.data = dict()
        # for tensor in self.proto.graph.initializer:
        #     name = tensor.name
        #
        #     if tensor.HasField("raw_data"):
        #         npt = numpy_helper.to_array(tensor)
        #         orv = OrtValue.ortvalue_from_numpy(npt)
        #         # self.data[name] = orv
        #         # set_external_data(tensor, location="in-memory-location")
        #         tensor.name = name
        #         # tensor.ClearField("raw_data")

        self.nodes = self._access_helper(self.proto.graph.node)  # type: ignore
        # self.initializers = self._access_helper(self.proto.graph.initializer)
        # print(self.proto.graph.input)
        # print(self.proto.graph.initializer)

        self.tensors = self._tensor_access(self)  # type: ignore

    # TODO: integrate with model manager/cache
    def create_session(self, height=None, width=None):
        if self.session is None or self.session_width != width or self.session_height != height:
            # onnx.save(self.proto, "tmp.onnx")
            # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # TODO: something to be able to get weight when they already moved outside of model proto
            # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            # self._external_data.update(**external_data)
            # sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
            # sess.enable_profiling = True

            # sess.intra_op_num_threads = 1
            # sess.inter_op_num_threads = 1
            # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL
            # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            # sess.enable_cpu_mem_arena = True
            # sess.enable_mem_pattern = True
            # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code
            self.session_height = height
            self.session_width = width
            if height and width:
                sess.add_free_dimension_override_by_name("unet_sample_batch", 2)
                sess.add_free_dimension_override_by_name("unet_sample_channels", 4)
                sess.add_free_dimension_override_by_name("unet_hidden_batch", 2)
                sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77)
                sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height)
                sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width)
                sess.add_free_dimension_override_by_name("unet_time_batch", 1)
            providers = []
            if self.provider:
                providers.append(self.provider)
            else:
                providers = get_available_providers()
            if "TensorrtExecutionProvider" in providers:
                providers.remove("TensorrtExecutionProvider")
            try:
                self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess)
            except Exception as e:
                raise e
            # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
            # self.io_binding = self.session.io_binding()

    def release_session(self):
        self.session = None
        import gc

        gc.collect()
        return

    def __call__(self, **kwargs):
        if self.session is None:
            raise Exception("You should call create_session before running model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        # output_names = self.session.get_outputs()
        # for k in inputs:
        #     self.io_binding.bind_cpu_input(k, inputs[k])
        # for name in output_names:
        #     self.io_binding.bind_output(name.name)
        # self.session.run_with_iobinding(self.io_binding, None)
        # return self.io_binding.copy_outputs_to_cpu()
        return self.session.run(None, inputs)

    # compatibility with diffusers load code
    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        subfolder: Optional[Union[str, Path]] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs: Any,
    ) -> Any:  # fixme
        file_name = file_name or ONNX_WEIGHTS_NAME

        if os.path.isdir(model_id):
            model_path = model_id
            if subfolder is not None:
                model_path = os.path.join(model_path, subfolder)
            model_path = os.path.join(model_path, file_name)

        else:
            model_path = model_id

        # load model from local directory
        if not os.path.isfile(model_path):
            raise Exception(f"Model not found: {model_path}")

        # TODO: session options
        return cls(str(model_path), provider=provider)
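A hedged usage sketch of this wrapper: the directory path is a placeholder, and the input names in the final comment assume a diffusers-style ONNX UNet export matching the free-dimension names used in create_session().

# Hedged sketch: load an ONNX submodel laid out diffusers-style and run it.
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel

unet = IAIOnnxRuntimeModel.from_pretrained(
    "/path/to/onnx-model",            # placeholder; must contain <subfolder>/model.onnx
    subfolder="unet",
    provider="CPUExecutionProvider",
)
unet.create_session(height=64, width=64)  # latent-space dims, used for the free-dimension overrides

# keyword arguments are converted to numpy arrays and fed to session.run(), e.g.:
# noise_pred = unet(sample=latents, timestep=t, encoder_hidden_states=text_embeddings)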
@@ -18,9 +18,9 @@ from .config import (
     InvalidModelConfigException,
     ModelConfigFactory,
     ModelFormat,
+    ModelRepoVariant,
     ModelType,
     ModelVariantType,
-    ModelRepoVariant,
     SchedulerPredictionType,
 )
 from .hash import FastModelHash
@@ -483,8 +483,8 @@ class FolderProbeBase(ProbeBase):

     def get_repo_variant(self) -> ModelRepoVariant:
         # get all files ending in .bin or .safetensors
-        weight_files = list(self.model_path.glob('**/*.safetensors'))
-        weight_files.extend(list(self.model_path.glob('**/*.bin')))
+        weight_files = list(self.model_path.glob("**/*.safetensors"))
+        weight_files.extend(list(self.model_path.glob("**/*.bin")))
         for x in weight_files:
             if ".fp16" in x.suffixes:
                 return ModelRepoVariant.FP16
@@ -496,6 +496,7 @@ class FolderProbeBase(ProbeBase):
                 return ModelRepoVariant.ONNX
         return ModelRepoVariant.DEFAULT

+
 class PipelineFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         with open(self.model_path / "unet" / "config.json", "r") as file:
@@ -540,7 +541,6 @@ class PipelineFolderProbe(FolderProbeBase):
         except Exception:
             pass
         return ModelVariantType.Normal
-


 class VaeFolderProbe(FolderProbeBase):
@@ -21,9 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Path
     base_type = probe.get_base_type()
     assert base_type == expected_type
     repo_variant = probe.get_repo_variant()
-    assert repo_variant == 'default'
+    assert repo_variant == "default"
+

 def test_repo_variant(datadir: Path):
     probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16")
     repo_variant = probe.get_repo_variant()
-    assert repo_variant == 'fp16'
+    assert repo_variant == "fp16"