From b8e875bb7359957ed542561589ac65e5facf6465 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 31 Jan 2024 23:37:59 -0500 Subject: [PATCH] add ram cache module and support files --- invokeai/backend/model_manager/config.py | 3 + .../backend/model_manager/load/__init__.py | 0 .../backend/model_manager/load/load_base.py | 193 ++++++++++ .../model_manager/load/load_default.py | 168 +++++++++ .../model_manager/load/memory_snapshot.py | 100 ++++++ .../backend/model_manager/load/model_util.py | 109 ++++++ .../model_manager/load/optimizations.py | 30 ++ .../model_manager/load/ram_cache/__init__.py | 0 .../load/ram_cache/ram_cache_base.py | 145 ++++++++ .../load/ram_cache/ram_cache_default.py | 332 ++++++++++++++++++ invokeai/backend/model_manager/load/vae.py | 31 ++ .../backend/model_manager/onnx_runtime.py | 216 ++++++++++++ invokeai/backend/model_manager/probe.py | 8 +- tests/test_model_probe.py | 5 +- 14 files changed, 1334 insertions(+), 6 deletions(-) create mode 100644 invokeai/backend/model_manager/load/__init__.py create mode 100644 invokeai/backend/model_manager/load/load_base.py create mode 100644 invokeai/backend/model_manager/load/load_default.py create mode 100644 invokeai/backend/model_manager/load/memory_snapshot.py create mode 100644 invokeai/backend/model_manager/load/model_util.py create mode 100644 invokeai/backend/model_manager/load/optimizations.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/__init__.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py create mode 100644 invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py create mode 100644 invokeai/backend/model_manager/load/vae.py create mode 100644 invokeai/backend/model_manager/onnx_runtime.py diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index b4685caf10..338669c873 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -152,6 +152,7 @@ class _DiffusersConfig(ModelConfigBase): format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT + class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" @@ -179,6 +180,7 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -214,6 +216,7 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False + class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py new file mode 100644 index 0000000000..7cb7222b71 --- /dev/null +++ b/invokeai/backend/model_manager/load/load_base.py @@ -0,0 +1,193 @@ +# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team +""" +Base class for model loading in InvokeAI. + +Use like this: + + loader = AnyModelLoader(...) 
+ loaded_model = loader.get_model('019ab39adfa1840455') + with loaded_model as model: # context manager moves model into VRAM + # do something with loaded_model +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from logging import Logger +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Type, Union + +import torch +from diffusers import DiffusionPipeline +from injector import inject + +from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.model_records import ModelRecordServiceBase +from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType +from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase +from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel +from invokeai.backend.model_manager.ram_cache import ModelCacheBase + +AnyModel = Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel] + + +class ModelLockerBase(ABC): + """Base class for the model locker used by the loader.""" + + @abstractmethod + def lock(self) -> None: + """Lock the contained model and move it into VRAM.""" + pass + + @abstractmethod + def unlock(self) -> None: + """Unlock the contained model, and remove it from VRAM.""" + pass + + @property + @abstractmethod + def model(self) -> AnyModel: + """Return the model.""" + pass + + +@dataclass +class LoadedModel: + """Context manager object that mediates transfer from RAM<->VRAM.""" + + config: AnyModelConfig + locker: ModelLockerBase + + def __enter__(self) -> AnyModel: # I think load_file() always returns a dict + """Context entry.""" + self.locker.lock() + return self.model + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Context exit.""" + self.locker.unlock() + + @property + def model(self) -> AnyModel: + """Return the model without locking it.""" + return self.locker.model() + + +class ModelLoaderBase(ABC): + """Abstract base class for loading models into RAM/VRAM.""" + + @abstractmethod + def __init__( + self, + app_config: InvokeAIAppConfig, + logger: Logger, + ram_cache: ModelCacheBase, + convert_cache: ModelConvertCacheBase, + ): + """Initialize the loader.""" + pass + + @abstractmethod + def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: + """ + Return a model given its key. + + Given a model key identified in the model configuration backend, + return a ModelInfo object that can be used to retrieve the model. + + :param model_config: Model configuration, as returned by ModelConfigRecordStore + :param submodel_type: an ModelType enum indicating the portion of + the model to retrieve (e.g. ModelType.Vae) + """ + pass + + @abstractmethod + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Return size in bytes of the model, calculated before loading.""" + pass + + +# TO DO: Better name? 
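+# Registration example: concrete loaders register themselves with the decorator defined below.
+# For instance, the VAE loader added later in this patch does:
+#
+#     @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
+#     class VaeDiffusersModel(ModelLoader):
+#         ...
+#
+# get_implementation() looks for an exact (base, type, format) registration first, then falls back
+# to one registered with the BaseModelType.Any wildcard.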
+class AnyModelLoader:
+    """This class manages the model loaders and invokes the correct one to load a model of given base and type."""
+
+    # this tracks the loader subclasses
+    _registry: Dict[str, Type[ModelLoaderBase]] = {}
+
+    @inject
+    def __init__(
+        self,
+        store: ModelRecordServiceBase,
+        app_config: InvokeAIAppConfig,
+        logger: Logger,
+        ram_cache: ModelCacheBase,
+        convert_cache: ModelConvertCacheBase,
+    ):
+        """Store the provided ModelRecordServiceBase and empty the registry."""
+        self._store = store
+        self._app_config = app_config
+        self._logger = logger
+        self._ram_cache = ram_cache
+        self._convert_cache = convert_cache
+
+    def get_model(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+        """
+        Return a model given its key.
+
+        Given a model key identified in the model configuration backend,
+        return a LoadedModel object that can be used to retrieve the model.
+
+        :param key: model key, as known to the config backend
+        :param submodel_type: a SubModelType enum indicating the portion of
+        the model to retrieve (e.g. SubModelType.Vae)
+        """
+        model_config = self._store.get_model(key)
+        implementation = self.__class__.get_implementation(
+            base=model_config.base, type=model_config.type, format=model_config.format
+        )
+        return implementation(
+            app_config=self._app_config,
+            logger=self._logger,
+            ram_cache=self._ram_cache,
+            convert_cache=self._convert_cache,
+        ).load_model(model_config, submodel_type)
+
+    @staticmethod
+    def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str:
+        return "-".join([base.value, type.value, format.value])
+
+    @classmethod
+    def get_implementation(cls, base: BaseModelType, type: ModelType, format: ModelFormat) -> Type[ModelLoaderBase]:
+        """Get subclass of ModelLoaderBase registered to handle base and type."""
+        key1 = cls._to_registry_key(base, type, format)  # for a specific base type
+        key2 = cls._to_registry_key(BaseModelType.Any, type, format)  # with wildcard Any
+        implementation = cls._registry.get(key1) or cls._registry.get(key2)
+        if not implementation:
+            raise NotImplementedError(
+                f"No subclass of ModelLoaderBase is registered for base={base}, type={type}, format={format}"
+            )
+        return implementation
+
+    @classmethod
+    def register(
+        cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
+    ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
+        """Define a decorator which registers the subclass of loader."""
+
+        def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
+            print("Registering class", subclass.__name__)
+            key = cls._to_registry_key(base, type, format)
+            cls._registry[key] = subclass
+            return subclass
+
+        return decorator
+
+
+# in __init__.py will call something like
+# def configure_loader_dependencies(binder):
+#   binder.bind(ModelRecordServiceBase, ApiDependencies.invoker.services.model_records, scope=singleton)
+#   binder.bind(InvokeAIAppConfig, ApiDependencies.invoker.services.configuration, scope=singleton)
+#   etc
+# injector = Injector(configure_loader_dependencies)
+# loader = injector.get(ModelFactory)
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
new file mode 100644
index 0000000000..eb2d432aaa
--- /dev/null
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Default implementation of model loading in InvokeAI."""
+
+import sys
+from logging import Logger
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+from diffusers import ModelMixin
+from diffusers.configuration_utils import ConfigMixin
+from injector import inject
+
+from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.model_manager import AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
+from invokeai.backend.model_manager.convert_cache import ModelConvertCacheBase
+from invokeai.backend.model_manager.load.load_base import AnyModel, LoadedModel, ModelLoaderBase
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
+from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
+from invokeai.backend.model_manager.ram_cache import ModelCacheBase, ModelLockerBase
+from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+
+
+class ConfigLoader(ConfigMixin):
+    """Subclass of ConfigMixin for loading diffusers configuration files."""
+
+    @classmethod
+    def load_config(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        """Load a diffusers ConfigMixin configuration."""
+        cls.config_name = kwargs.pop("config_name")
+        # Diffusers doesn't provide typing info
+        return super().load_config(*args, **kwargs)  # type: ignore
+
+
+# TO DO: The loader is not thread safe!
+class ModelLoader(ModelLoaderBase):
+    """Default implementation of ModelLoaderBase."""
+
+    @inject  # can inject instances of each of the classes in the call signature
+    def __init__(
+        self,
+        app_config: InvokeAIAppConfig,
+        logger: Logger,
+        ram_cache: ModelCacheBase,
+        convert_cache: ModelConvertCacheBase,
+    ):
+        """Initialize the loader."""
+        self._app_config = app_config
+        self._logger = logger
+        self._ram_cache = ram_cache
+        self._convert_cache = convert_cache
+        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._size: Optional[int] = None  # model size
+
+    def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
+        """
+        Return a model given its configuration.
+
+        Given a model's configuration as returned by the ModelRecordConfigStore service,
+        return a LoadedModel object that can be used for inference.
+
+        :param model_config: Configuration record for this model
+        :param submodel_type: a SubModelType enum indicating the portion of
+        the model to retrieve (e.g. SubModelType.Vae)
+        """
+        if model_config.type == "main" and not submodel_type:
+            raise InvalidModelConfigException("submodel_type is required when loading a main model")
+
+        model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
+        if is_submodel_override:
+            submodel_type = None
+
+        if not model_path.exists():
+            raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
+
+        model_path = self._convert_if_needed(model_config, model_path, submodel_type)
+        locker = self._load_if_needed(model_config, model_path, submodel_type)
+        return LoadedModel(config=model_config, locker=locker)
+
+    # IMPORTANT: This needs to be overridden in the StableDiffusion subclass so as to handle vae overrides
+    # and submodels!!!!
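+    # A rough sketch of what such an override might look like (hypothetical; "config.vae" is an
+    # assumption about the main-model config, not something defined in this file):
+    #
+    #     def _get_model_path(self, config, submodel_type=None):
+    #         path = (self._app_config.models_path / config.path).resolve()
+    #         if submodel_type == SubModelType.Vae and getattr(config, "vae", None):
+    #             # a standalone VAE override replaces the pipeline's own vae submodel
+    #             return (self._app_config.models_path / config.vae).resolve(), True
+    #         return path, False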
+ def _get_model_path( + self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> Tuple[Path, bool]: + model_base = self._app_config.models_path + return ((model_base / config.path).resolve(), False) + + def _convert_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> Path: + if not self._needs_conversion(config): + return model_path + + self._convert_cache.make_room(self._size or self.get_size_fs(config, model_path, submodel_type)) + cache_path: Path = self._convert_cache.cache_path(config.key) + if cache_path.exists(): + return cache_path + + self._convert_model(model_path, cache_path) + return cache_path + + def _needs_conversion(self, config: AnyModelConfig) -> bool: + return False + + def _load_if_needed( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> ModelLockerBase: + # TO DO: This is not thread safe! + if self._ram_cache.exists(config.key, submodel_type): + return self._ram_cache.get(config.key, submodel_type) + + model_variant = getattr(config, "repo_variant", None) + self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type)) + + # This is where the model is actually loaded! + with skip_torch_weight_init(): + loaded_model = self._load_model(model_path, model_variant=model_variant, submodel_type=submodel_type) + + self._ram_cache.put( + config.key, + submodel_type=submodel_type, + model=loaded_model, + ) + + return self._ram_cache.get(config.key, submodel_type) + + def get_size_fs( + self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None + ) -> int: + """Get the size of the model on disk.""" + return calc_model_size_by_fs( + model_path=model_path, + subfolder=submodel_type.value if submodel_type else None, + variant=config.repo_variant if hasattr(config, "repo_variant") else None, + ) + + def _convert_model(self, model_path: Path, cache_path: Path) -> None: + raise NotImplementedError + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + raise NotImplementedError + + def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]: + return ConfigLoader.load_config(model_path, config_name=config_name) + + # TO DO: Add exception handling + def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type + if module in ["diffusers", "transformers"]: + res_type = sys.modules[module] + else: + res_type = sys.modules["diffusers"].pipelines + result: ModelMixin = getattr(res_type, class_name) + return result + + # TO DO: Add exception handling + def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: + if submodel_type: + config = self._load_diffusers_config(model_path, config_name="model_index.json") + module, class_name = config[submodel_type.value] + return self._hf_definition_to_type(module=module, class_name=class_name) + else: + config = self._load_diffusers_config(model_path, config_name="config.json") + class_name = config["_class_name"] + return self._hf_definition_to_type(module="diffusers", class_name=class_name) diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py new file mode 100644 index 0000000000..504829a427 --- /dev/null +++ 
b/invokeai/backend/model_manager/load/memory_snapshot.py @@ -0,0 +1,100 @@ +import gc +from typing import Optional + +import psutil +import torch +from typing_extensions import Self + +from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2 + +GB = 2**30 # 1 GB + + +class MemorySnapshot: + """A snapshot of RAM and VRAM usage. All values are in bytes.""" + + def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]): + """Initialize a MemorySnapshot. + + Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`. + + Args: + process_ram (int): CPU RAM used by the current process. + vram (Optional[int]): VRAM used by torch. + malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil. + """ + self.process_ram = process_ram + self.vram = vram + self.malloc_info = malloc_info + + @classmethod + def capture(cls, run_garbage_collector: bool = True) -> Self: + """Capture and return a MemorySnapshot. + + Note: This function has significant overhead, particularly if `run_garbage_collector == True`. + + Args: + run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM + usage. Defaults to True. + + Returns: + MemorySnapshot + """ + if run_garbage_collector: + gc.collect() + + # According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is + # supported on all platforms. + process_ram = psutil.Process().memory_info().rss + + if torch.cuda.is_available(): + vram = torch.cuda.memory_allocated() + else: + # TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have + # time to test it properly. + vram = None + + try: + malloc_info = LibcUtil().mallinfo2() # type: ignore + except (OSError, AttributeError): + # OSError: This is expected in environments that do not have the 'libc.so.6' shared library. + # AttributeError: This is expected in environments that have `libc.so.6` but do not have the `mallinfo2` (e.g. glibc < 2.33) + # TODO: Does `mallinfo` work? 
+ malloc_info = None + + return cls(process_ram, vram, malloc_info) + + +def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str: + """Get a pretty string describing the difference between two `MemorySnapshot`s.""" + + def get_msg_line(prefix: str, val1: int, val2: int) -> str: + diff = val2 - val1 + return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n" + + msg = "" + + if snapshot_1 is None or snapshot_2 is None: + return msg + + msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram) + + if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None: + msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd) + + msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks) + + msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks) + + libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd + libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2) + + libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd + libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd + msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2) + + if snapshot_1.vram is not None and snapshot_2.vram is not None: + msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram) + + return msg diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py new file mode 100644 index 0000000000..18407cbca2 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_util.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024 The InvokeAI Development Team +"""Various utility functions needed by the loader and caching system.""" + +import json +from pathlib import Path +from typing import Optional, Union + +import torch +from diffusers import DiffusionPipeline + +from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel + + +def calc_model_size_by_data(model: Union[DiffusionPipeline, torch.nn.Module, IAIOnnxRuntimeModel]) -> int: + """Get size of a model in memory in bytes.""" + if isinstance(model, DiffusionPipeline): + return _calc_pipeline_by_data(model) + elif isinstance(model, torch.nn.Module): + return _calc_model_by_data(model) + elif isinstance(model, IAIOnnxRuntimeModel): + return _calc_onnx_model_by_data(model) + else: + return 0 + + +def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int: + res = 0 + assert hasattr(pipeline, "components") + for submodel_key in pipeline.components.keys(): + submodel = getattr(pipeline, submodel_key) + if submodel is not None and isinstance(submodel, torch.nn.Module): + res += _calc_model_by_data(submodel) + return res + + +def _calc_model_by_data(model: torch.nn.Module) -> int: + mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) + mem: int = mem_params + mem_bufs # in bytes + return mem + + +def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int: + tensor_size = model.tensors.size() * 2 # The session doubles this + mem = tensor_size # in bytes + return mem + + +def 
calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int: + """Estimate the size of a model on disk in bytes.""" + if subfolder is not None: + model_path = model_path / subfolder + + # this can happen when, for example, the safety checker is not downloaded. + if not model_path.exists(): + return 0 + + all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()] + + fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name} + bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name} + other_files = set(all_files) - fp16_files - bit8_files + + if variant is None: + files = other_files + elif variant == "fp16": + files = fp16_files + elif variant == "8bit": + files = bit8_files + else: + raise NotImplementedError(f"Unknown variant: {variant}") + + # try read from index if exists + index_postfix = ".index.json" + if variant is not None: + index_postfix = f".index.{variant}.json" + + for file in files: + if not file.name.endswith(index_postfix): + continue + try: + with open(model_path / file, "r") as f: + index_data = json.loads(f.read()) + return int(index_data["metadata"]["total_size"]) + except Exception: + pass + + # calculate files size if there is no index file + formats = [ + (".safetensors",), # safetensors + (".bin",), # torch + (".onnx", ".pb"), # onnx + (".msgpack",), # flax + (".ckpt",), # tf + (".h5",), # tf2 + ] + + for file_format in formats: + model_files = [f for f in files if f.suffix in file_format] + if len(model_files) == 0: + continue + + model_size = 0 + for model_file in model_files: + file_stats = (model_path / model_file).stat() + model_size += file_stats.st_size + return model_size + + return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu diff --git a/invokeai/backend/model_manager/load/optimizations.py b/invokeai/backend/model_manager/load/optimizations.py new file mode 100644 index 0000000000..a46d262175 --- /dev/null +++ b/invokeai/backend/model_manager/load/optimizations.py @@ -0,0 +1,30 @@ +from contextlib import contextmanager + +import torch + + +def _no_op(*args, **kwargs): + pass + + +@contextmanager +def skip_torch_weight_init(): + """A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.) + to skip weight initialization. + + By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular + distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is + completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager + monkey-patches common torch layers to skip the weight initialization step. 
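+
+    A minimal usage sketch (`checkpoint_state_dict` below is a stand-in for real checkpoint data):
+
+        with skip_torch_weight_init():
+            layer = torch.nn.Linear(4096, 4096)  # weight/bias are allocated but not initialized
+        layer.load_state_dict(checkpoint_state_dict)  # values then come from the checkpoint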
+ """ + torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding] + saved_functions = [m.reset_parameters for m in torch_modules] + + try: + for torch_module in torch_modules: + torch_module.reset_parameters = _no_op + + yield None + finally: + for torch_module, saved_function in zip(torch_modules, saved_functions, strict=True): + torch_module.reset_parameters = saved_function diff --git a/invokeai/backend/model_manager/load/ram_cache/__init__.py b/invokeai/backend/model_manager/load/ram_cache/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py b/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py new file mode 100644 index 0000000000..cd80d1e78b --- /dev/null +++ b/invokeai/backend/model_manager/load/ram_cache/ram_cache_base.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from logging import Logger +from typing import Dict, Optional + +import torch + +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase + + +@dataclass +class CacheStats(object): + """Data object to record statistics on cache hits/misses.""" + + hits: int = 0 # cache hits + misses: int = 0 # cache misses + high_watermark: int = 0 # amount of cache used + in_cache: int = 0 # number of models in cache + cleared: int = 0 # number of models cleared to make space + cache_size: int = 0 # total size of cache + loaded_model_sizes: Dict[str, int] = field(default_factory=dict) + + +@dataclass +class CacheRecord: + """Elements of the cache.""" + + key: str + model: AnyModel + size: int + _locks: int = 0 + + def lock(self) -> None: + """Lock this record.""" + self._locks += 1 + + def unlock(self) -> None: + """Unlock this record.""" + self._locks -= 1 + assert self._locks >= 0 + + @property + def locked(self) -> bool: + """Return true if record is locked.""" + return self._locks > 0 + + +class ModelCacheBase(ABC): + """Virtual base class for RAM model cache.""" + + @property + @abstractmethod + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + pass + + @property + @abstractmethod + def execution_device(self) -> torch.device: + """Return the exection device (e.g. 
"cuda" for VRAM).""" + pass + + @property + @abstractmethod + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + pass + + @abstractmethod + def offload_unlocked_models(self) -> None: + """Offload from VRAM any models not actively in use.""" + pass + + @abstractmethod + def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None: + """Move model into the indicated device.""" + pass + + @property + @abstractmethod + def logger(self) -> Logger: + """Return the logger used by the cache.""" + pass + + @abstractmethod + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + pass + + @abstractmethod + def put( + self, + key: str, + model: AnyModel, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + pass + + @abstractmethod + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> ModelLockerBase: + """ + Retrieve model locker object using key and optional submodel_type. + + This may return an UnknownModelException if the model is not in the cache. + """ + pass + + @abstractmethod + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + pass + + @abstractmethod + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + pass + + @abstractmethod + def get_stats(self) -> CacheStats: + """Return cache hit/miss/size statistics.""" + pass + + @abstractmethod + def print_cuda_stats(self) -> None: + """Log debugging information on CUDA usage.""" + pass diff --git a/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py b/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py new file mode 100644 index 0000000000..bd43e978c8 --- /dev/null +++ b/invokeai/backend/model_manager/load/ram_cache/ram_cache_default.py @@ -0,0 +1,332 @@ +# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team +# TODO: Add Stalker's proper name to copyright +""" +Manage a RAM cache of diffusion/transformer models for fast switching. +They are moved between GPU VRAM and CPU RAM as necessary. If the cache +grows larger than a preset maximum, then the least recently used +model will be cleared and (re)loaded from disk when next needed. + +The cache returns context manager generators designed to load the +model into the GPU within the context, and unload outside the +context. 
Use like this: + + cache = ModelCache(max_cache_size=7.5) + with cache.get_model('runwayml/stable-diffusion-1-5') as SD1, + cache.get_model('stabilityai/stable-diffusion-2') as SD2: + do_something_in_GPU(SD1,SD2) + + +""" + +import math +import time +from contextlib import suppress +from logging import Logger +from typing import Any, Dict, List, Optional + +import torch + +from invokeai.app.services.model_records import UnknownModelException +from invokeai.backend.model_manager import SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModel, ModelLockerBase +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data +from invokeai.backend.model_manager.load.ram_cache.ram_cache_base import CacheRecord, CacheStats, ModelCacheBase +from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.logging import InvokeAILogger + +if choose_torch_device() == torch.device("mps"): + from torch import mps + +# Maximum size of the cache, in gigs +# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously +DEFAULT_MAX_CACHE_SIZE = 6.0 + +# amount of GPU memory to hold in reserve for use by generations (GB) +DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75 + +# actual size of a gig +GIG = 1073741824 + +# Size of a MB in bytes. +MB = 2**20 + + +class ModelCache(ModelCacheBase): + """Implementation of ModelCacheBase.""" + + def __init__( + self, + max_cache_size: float = DEFAULT_MAX_CACHE_SIZE, + max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE, + execution_device: torch.device = torch.device("cuda"), + storage_device: torch.device = torch.device("cpu"), + precision: torch.dtype = torch.float16, + sequential_offload: bool = False, + lazy_offloading: bool = True, + sha_chunksize: int = 16777216, + log_memory_usage: bool = False, + logger: Optional[Logger] = None, + ): + """ + Initialize the model RAM cache. + + :param max_cache_size: Maximum size of the RAM cache [6.0 GB] + :param execution_device: Torch device to load active model into [torch.device('cuda')] + :param storage_device: Torch device to save inactive model in [torch.device('cpu')] + :param precision: Precision for loaded models [torch.float16] + :param lazy_offloading: Keep model in VRAM until another model needs to be loaded + :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially + :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache + operation, and the result will be logged (at debug level). There is a time cost to capturing the memory + snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's + behaviour. 
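+
+        Example (illustrative):
+
+            cache = ModelCache(max_cache_size=7.5, max_vram_cache_size=0)  # no VRAM cache
+            # with the VRAM cache disabled, lazy offloading is also disabled (see the check below)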
+ """ + # allow lazy offloading only when vram cache enabled + self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0 + self._precision: torch.dtype = precision + self._max_cache_size: float = max_cache_size + self._max_vram_cache_size: float = max_vram_cache_size + self._execution_device: torch.device = execution_device + self._storage_device: torch.device = storage_device + self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) + self._log_memory_usage = log_memory_usage + + # used for stats collection + self.stats = None + + self._cached_models: Dict[str, CacheRecord] = {} + self._cache_stack: List[str] = [] + + class ModelLocker(ModelLockerBase): + """Internal class that mediates movement in and out of GPU.""" + + def __init__(self, cache: ModelCacheBase, cache_entry: CacheRecord): + """ + Initialize the model locker. + + :param cache: The ModelCache object + :param cache_entry: The entry in the model cache + """ + self._cache = cache + self._cache_entry = cache_entry + + @property + def model(self) -> AnyModel: + """Return the model without moving it around.""" + return self._cache_entry.model + + def lock(self) -> Any: + """Move the model into the execution device (GPU) and lock it.""" + if not hasattr(self.model, "to"): + return self.model + + # NOTE that the model has to have the to() method in order for this code to move it into GPU! + self._cache_entry.lock() + + try: + if self._cache.lazy_offloading: + self._cache.offload_unlocked_models() + + self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) + + self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") + self._cache.print_cuda_stats() + + except Exception: + self._cache_entry.unlock() + raise + return self.model + + def unlock(self) -> None: + """Call upon exit from context.""" + if not hasattr(self.model, "to"): + return + + self._cache_entry.unlock() + if not self._cache.lazy_offloading: + self._cache.offload_unlocked_models() + self._cache.print_cuda_stats() + + @property + def logger(self) -> Logger: + """Return the logger used by the cache.""" + return self._logger + + @property + def lazy_offloading(self) -> bool: + """Return true if the cache is configured to lazily offload models in VRAM.""" + return self._lazy_offloading + + @property + def storage_device(self) -> torch.device: + """Return the storage device (e.g. "CPU" for RAM).""" + return self._storage_device + + @property + def execution_device(self) -> torch.device: + """Return the exection device (e.g. 
"cuda" for VRAM).""" + return self._execution_device + + def cache_size(self) -> int: + """Get the total size of the models currently cached.""" + total = 0 + for cache_record in self._cached_models.values(): + total += cache_record.size + return total + + def exists( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> bool: + """Return true if the model identified by key and submodel_type is in the cache.""" + key = self._make_cache_key(key, submodel_type) + return key in self._cached_models + + def put( + self, + key: str, + model: AnyModel, + submodel_type: Optional[SubModelType] = None, + ) -> None: + """Store model under key and optional submodel_type.""" + key = self._make_cache_key(key, submodel_type) + assert key not in self._cached_models + + loaded_model_size = calc_model_size_by_data(model) + cache_record = CacheRecord(key, model, loaded_model_size) + self._cached_models[key] = cache_record + self._cache_stack.append(key) + + def get( + self, + key: str, + submodel_type: Optional[SubModelType] = None, + ) -> ModelLockerBase: + """ + Retrieve model using key and optional submodel_type. + + This may return an UnknownModelException if the model is not in the cache. + """ + key = self._make_cache_key(key, submodel_type) + if key not in self._cached_models: + raise UnknownModelException + + # this moves the entry to the top (right end) of the stack + with suppress(Exception): + self._cache_stack.remove(key) + self._cache_stack.append(key) + cache_entry = self._cached_models[key] + return self.ModelLocker( + cache=self, + cache_entry=cache_entry, + ) + + def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: + if self._log_memory_usage: + return MemorySnapshot.capture() + return None + + def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: + if submodel_type: + return f"{model_key}:{submodel_type.value}" + else: + return model_key + + def offload_unlocked_models(self) -> None: + """Move any unused models from VRAM.""" + reserved = self._max_vram_cache_size * GIG + vram_in_use = torch.cuda.memory_allocated() + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): + if vram_in_use <= reserved: + break + if not cache_entry.locked: + self.move_model_to_device(cache_entry, self.storage_device) + + vram_in_use = torch.cuda.memory_allocated() + self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB") + + torch.cuda.empty_cache() + if choose_torch_device() == torch.device("mps"): + mps.empty_cache() + + # TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size + # for printing debugging messages. Revisit whether this is necessary + def move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None: + """Move model into the indicated device.""" + # These attributes are not in the base class but in derived classes + assert hasattr(cache_entry.model, "device") + assert hasattr(cache_entry.model, "to") + + source_device = cache_entry.model.device + + # Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support + # multi-GPU. 
+ if torch.device(source_device).type == torch.device(target_device).type: + return + + start_model_to_time = time.time() + snapshot_before = self._capture_memory_snapshot() + cache_entry.model.to(target_device) + snapshot_after = self._capture_memory_snapshot() + end_model_to_time = time.time() + self.logger.debug( + f"Moved model '{cache_entry.key}' from {source_device} to" + f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n" + f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + if ( + snapshot_before is not None + and snapshot_after is not None + and snapshot_before.vram is not None + and snapshot_after.vram is not None + ): + vram_change = abs(snapshot_before.vram - snapshot_after.vram) + + # If the estimated model size does not match the change in VRAM, log a warning. + if not math.isclose( + vram_change, + cache_entry.size, + rel_tol=0.1, + abs_tol=10 * MB, + ): + self.logger.debug( + f"Moving model '{cache_entry.key}' from {source_device} to" + f" {target_device} caused an unexpected change in VRAM usage. The model's" + " estimated size may be incorrect. Estimated model size:" + f" {(cache_entry.size/GIG):.3f} GB.\n" + f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + ) + + def print_cuda_stats(self) -> None: + """Log CUDA diagnostics.""" + vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) + ram = "%4.2fG" % self.cache_size() + + cached_models = 0 + loaded_models = 0 + locked_models = 0 + for cache_record in self._cached_models.values(): + cached_models += 1 + assert hasattr(cache_record.model, "device") + if cache_record.model.device is self.storage_device: + loaded_models += 1 + if cache_record.locked: + locked_models += 1 + + self.logger.debug( + f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ =" + f" {cached_models}/{loaded_models}/{locked_models}" + ) + + def get_stats(self) -> CacheStats: + """Return cache hit/miss/size statistics.""" + raise NotImplementedError + + def make_room(self, size: int) -> None: + """Make enough room in the cache to accommodate a new model of indicated size.""" + raise NotImplementedError diff --git a/invokeai/backend/model_manager/load/vae.py b/invokeai/backend/model_manager/load/vae.py new file mode 100644 index 0000000000..a6cbe241e1 --- /dev/null +++ b/invokeai/backend/model_manager/load/vae.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team +"""Class for VAE model loading in InvokeAI.""" + +from pathlib import Path +from typing import Dict, Optional + +import torch + +from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelRepoVariant, ModelType, SubModelType +from invokeai.backend.model_manager.load.load_base import AnyModelLoader +from invokeai.backend.model_manager.load.load_default import ModelLoader + + +@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers) +class VaeDiffusersModel(ModelLoader): + """Class to load VAE models.""" + + def _load_model( + self, + model_path: Path, + model_variant: Optional[ModelRepoVariant] = None, + submodel_type: Optional[SubModelType] = None, + ) -> Dict[str, torch.Tensor]: + if submodel_type is not None: + raise Exception("There are no submodels in VAEs") + vae_class = self._get_hf_load_class(model_path) + variant = model_variant.value if model_variant else "" + result: Dict[str, torch.Tensor] = vae_class.from_pretrained( + model_path, torch_dtype=self._torch_dtype, variant=variant + ) # type: ignore + return result diff --git a/invokeai/backend/model_manager/onnx_runtime.py b/invokeai/backend/model_manager/onnx_runtime.py new file mode 100644 index 0000000000..f79fa01569 --- /dev/null +++ b/invokeai/backend/model_manager/onnx_runtime.py @@ -0,0 +1,216 @@ +# Copyright (c) 2024 The InvokeAI Development Team +import os +import sys +from pathlib import Path +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import onnx +from onnx import numpy_helper +from onnxruntime import InferenceSession, SessionOptions, get_available_providers + +ONNX_WEIGHTS_NAME = "model.onnx" + + +# NOTE FROM LS: This was copied from Stalker's original implementation. 
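+# Rough usage sketch (illustrative; actual input names depend on the exported ONNX graph):
+#
+#     model = IAIOnnxRuntimeModel.from_pretrained("/path/to/model_dir", provider="CPUExecutionProvider")
+#     model.create_session(height=512, width=512)  # must be called before inference
+#     outputs = model(sample=latents, timestep=t, encoder_hidden_states=text_embeds)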
+# I have not yet gone through and fixed all the type hints +class IAIOnnxRuntimeModel: + class _tensor_access: + def __init__(self, model): # type: ignore + self.model = model + self.indexes = {} + for idx, obj in enumerate(self.model.proto.graph.initializer): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + value = self.model.proto.graph.initializer[self.indexes[key]] + return numpy_helper.to_array(value) + + def __setitem__(self, key: str, value: np.ndarray): # type: ignore + new_node = numpy_helper.from_array(value) + # set_external_data(new_node, location="in-memory-location") + new_node.name = key + # new_node.ClearField("raw_data") + del self.model.proto.graph.initializer[self.indexes[key]] + self.model.proto.graph.initializer.insert(self.indexes[key], new_node) + # self.model.data[key] = OrtValue.ortvalue_from_numpy(value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return self.indexes[key] in self.model.proto.graph.initializer + + def items(self) -> List[Tuple[str, Any]]: # fixme + raise NotImplementedError("tensor.items") + # return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + raise NotImplementedError("tensor.values") + # return [obj for obj in self.raw_proto] + + def size(self) -> int: + bytesSum = 0 + for node in self.model.proto.graph.initializer: + bytesSum += sys.getsizeof(node.raw_data) + return bytesSum + + class _access_helper: + def __init__(self, raw_proto): # type: ignore + self.indexes = {} + self.raw_proto = raw_proto + for idx, obj in enumerate(raw_proto): + self.indexes[obj.name] = idx + + def __getitem__(self, key: str): # type: ignore + return self.raw_proto[self.indexes[key]] + + def __setitem__(self, key: str, value): # type: ignore + index = self.indexes[key] + del self.raw_proto[index] + self.raw_proto.insert(index, value) + + # __delitem__ + + def __contains__(self, key: str) -> bool: + return key in self.indexes + + def items(self) -> List[Tuple[str, Any]]: + return [(obj.name, obj) for obj in self.raw_proto] + + def keys(self) -> List[str]: + return list(self.indexes.keys()) + + def values(self) -> List[Any]: # fixme + return list(self.raw_proto) + + def __init__(self, model_path: str, provider: Optional[str]): + self.path = model_path + self.session = None + self.provider = provider + """ + self.data_path = self.path + "_data" + if not os.path.exists(self.data_path): + print(f"Moving model tensors to separate file: {self.data_path}") + tmp_proto = onnx.load(model_path, load_external_data=True) + onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False) + del tmp_proto + gc.collect() + + self.proto = onnx.load(model_path, load_external_data=False) + """ + + self.proto = onnx.load(model_path, load_external_data=True) + # self.data = dict() + # for tensor in self.proto.graph.initializer: + # name = tensor.name + + # if tensor.HasField("raw_data"): + # npt = numpy_helper.to_array(tensor) + # orv = OrtValue.ortvalue_from_numpy(npt) + # # self.data[name] = orv + # # set_external_data(tensor, location="in-memory-location") + # tensor.name = name + # # tensor.ClearField("raw_data") + + self.nodes = self._access_helper(self.proto.graph.node) # type: ignore + # self.initializers = self._access_helper(self.proto.graph.initializer) + # print(self.proto.graph.input) + # 
print(self.proto.graph.initializer) + + self.tensors = self._tensor_access(self) # type: ignore + + # TODO: integrate with model manager/cache + def create_session(self, height=None, width=None): + if self.session is None or self.session_width != width or self.session_height != height: + # onnx.save(self.proto, "tmp.onnx") + # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) + # TODO: something to be able to get weight when they already moved outside of model proto + # (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) + sess = SessionOptions() + # self._external_data.update(**external_data) + # sess.add_external_initializers(list(self.data.keys()), list(self.data.values())) + # sess.enable_profiling = True + + # sess.intra_op_num_threads = 1 + # sess.inter_op_num_threads = 1 + # sess.execution_mode = ExecutionMode.ORT_SEQUENTIAL + # sess.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + # sess.enable_cpu_mem_arena = True + # sess.enable_mem_pattern = True + # sess.add_session_config_entry("session.intra_op.use_xnnpack_threadpool", "1") ########### It's the key code + self.session_height = height + self.session_width = width + if height and width: + sess.add_free_dimension_override_by_name("unet_sample_batch", 2) + sess.add_free_dimension_override_by_name("unet_sample_channels", 4) + sess.add_free_dimension_override_by_name("unet_hidden_batch", 2) + sess.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + sess.add_free_dimension_override_by_name("unet_sample_height", self.session_height) + sess.add_free_dimension_override_by_name("unet_sample_width", self.session_width) + sess.add_free_dimension_override_by_name("unet_time_batch", 1) + providers = [] + if self.provider: + providers.append(self.provider) + else: + providers = get_available_providers() + if "TensorrtExecutionProvider" in providers: + providers.remove("TensorrtExecutionProvider") + try: + self.session = InferenceSession(self.proto.SerializeToString(), providers=providers, sess_options=sess) + except Exception as e: + raise e + # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) + # self.io_binding = self.session.io_binding() + + def release_session(self): + self.session = None + import gc + + gc.collect() + return + + def __call__(self, **kwargs): + if self.session is None: + raise Exception("You should call create_session before running model") + + inputs = {k: np.array(v) for k, v in kwargs.items()} + # output_names = self.session.get_outputs() + # for k in inputs: + # self.io_binding.bind_cpu_input(k, inputs[k]) + # for name in output_names: + # self.io_binding.bind_output(name.name) + # self.session.run_with_iobinding(self.io_binding, None) + # return self.io_binding.copy_outputs_to_cpu() + return self.session.run(None, inputs) + + # compatability with diffusers load code + @classmethod + def from_pretrained( + cls, + model_id: Union[str, Path], + subfolder: Optional[Union[str, Path]] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["SessionOptions"] = None, + **kwargs: Any, + ) -> Any: # fixme + file_name = file_name or ONNX_WEIGHTS_NAME + + if os.path.isdir(model_id): + model_path = model_id + if subfolder is not None: + model_path = os.path.join(model_path, subfolder) + model_path = os.path.join(model_path, file_name) + + else: + model_path = 
model_id + + # load model from local directory + if not os.path.isfile(model_path): + raise Exception(f"Model not found: {model_path}") + + # TODO: session options + return cls(str(model_path), provider=provider) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index ba3ac3dd0c..9fd118b782 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -18,9 +18,9 @@ from .config import ( InvalidModelConfigException, ModelConfigFactory, ModelFormat, + ModelRepoVariant, ModelType, ModelVariantType, - ModelRepoVariant, SchedulerPredictionType, ) from .hash import FastModelHash @@ -483,8 +483,8 @@ class FolderProbeBase(ProbeBase): def get_repo_variant(self) -> ModelRepoVariant: # get all files ending in .bin or .safetensors - weight_files = list(self.model_path.glob('**/*.safetensors')) - weight_files.extend(list(self.model_path.glob('**/*.bin'))) + weight_files = list(self.model_path.glob("**/*.safetensors")) + weight_files.extend(list(self.model_path.glob("**/*.bin"))) for x in weight_files: if ".fp16" in x.suffixes: return ModelRepoVariant.FP16 @@ -496,6 +496,7 @@ class FolderProbeBase(ProbeBase): return ModelRepoVariant.ONNX return ModelRepoVariant.DEFAULT + class PipelineFolderProbe(FolderProbeBase): def get_base_type(self) -> BaseModelType: with open(self.model_path / "unet" / "config.json", "r") as file: @@ -540,7 +541,6 @@ class PipelineFolderProbe(FolderProbeBase): except Exception: pass return ModelVariantType.Normal - class VaeFolderProbe(FolderProbeBase): diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py index 415559a64c..aacae06a8b 100644 --- a/tests/test_model_probe.py +++ b/tests/test_model_probe.py @@ -21,9 +21,10 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat base_type = probe.get_base_type() assert base_type == expected_type repo_variant = probe.get_repo_variant() - assert repo_variant == 'default' + assert repo_variant == "default" + def test_repo_variant(datadir: Path): probe = VaeFolderProbe(datadir / "vae" / "taesdxl-fp16") repo_variant = probe.get_repo_variant() - assert repo_variant == 'fp16' + assert repo_variant == "fp16"