diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py
index dd126a322d..06066dd6b1 100644
--- a/invokeai/backend/__init__.py
+++ b/invokeai/backend/__init__.py
@@ -10,7 +10,7 @@ from .generator import (
     Img2Img,
     Inpaint
 )
-from .model_management import ModelManager, SDModelComponent
+from .model_management import ModelManager
 from .safety_checker import SafetyChecker
 from .args import Args
 from .globals import Globals
diff --git a/invokeai/backend/globals.py b/invokeai/backend/globals.py
index c5417e03db..37a59b1135 100644
--- a/invokeai/backend/globals.py
+++ b/invokeai/backend/globals.py
@@ -93,6 +93,8 @@ def global_converted_ckpts_dir() -> Path:
 def global_set_root(root_dir: Union[str, Path]):
     Globals.root = root_dir
 
+def global_resolve_path(path: Union[str,Path]):
+    return Path(Globals.root,path).resolve()
 
 def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
     """
diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py
index 64ef4e9ee9..07b567ce7a 100644
--- a/invokeai/backend/model_management/__init__.py
+++ b/invokeai/backend/model_management/__init__.py
@@ -1,11 +1,5 @@
 """
 Initialization file for invokeai.backend.model_management
 """
-from .convert_ckpt_to_diffusers import (
-    convert_ckpt_to_diffusers,
-    load_pipeline_from_original_stable_diffusion_ckpt,
-)
-from .model_manager import ModelManager,SDModelComponent
-
-
-
+from .model_manager import ModelManager
+from .model_cache import ModelCache, ModelStatus
diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py
index 8aec5a01d9..aaa69da7f3 100644
--- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py
+++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py
@@ -29,6 +29,7 @@ import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import global_cache_dir, global_config_dir
 
 from .model_manager import ModelManager, SDLegacyType
+from .model_cache import ModelCache
 
 try:
     from omegaconf import OmegaConf
@@ -1100,7 +1101,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
 
     if Path(checkpoint_path).suffix == '.ckpt':
         if scan_needed:
-            ModelManager.scan_model(checkpoint_path,checkpoint_path)
+            ModelCache.scan_model(checkpoint_path,checkpoint_path)
         checkpoint = torch.load(checkpoint_path)
     else:
         checkpoint = load_file(checkpoint_path)
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 02448c59a6..95b4e165f6 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -24,10 +24,10 @@ from collections.abc import Generator
 from collections import Counter
 from enum import Enum
 from pathlib import Path
-from typing import Sequence, Union
+from typing import Sequence, Union, Tuple, types
 
 import torch
-from diffusers import AutoencoderKL, SchedulerMixin, UNet2DConditionModel
+from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
 from diffusers import logging as diffusers_logging
 from diffusers.pipelines.stable_diffusion.safety_checker import \
     StableDiffusionSafetyChecker
@@ -40,7 +40,6 @@ from transformers import logging as transformers_logging
 import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
-from . import load_pipeline_from_original_stable_diffusion_ckpt
 
 
 MAX_MODELS = 4
@@ -54,9 +53,17 @@ class SDModelType(Enum):
     scheduler=SchedulerMixin
     safety_checker=StableDiffusionSafetyChecker
     feature_extractor=CLIPFeatureExtractor
+
+class ModelStatus(Enum):
+    unknown='unknown'
+    not_loaded='not loaded'
+    in_ram='cached'
+    in_vram='in gpu'
+    active='locked in gpu'
 
 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]
+DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
 
 # Legacy information needed to load a legacy checkpoint file
 class LegacyInfo(BaseModel):
@@ -81,6 +88,7 @@ class ModelCache(object):
         sequential_offload: bool=False,
         lazy_offloading: bool=True,
         sha_chunksize: int = 16777216,
+        logger: types.ModuleType = logger
     ):
         '''
         :param max_models: Maximum number of models to cache in CPU RAM [4]
@@ -100,9 +108,11 @@ class ModelCache(object):
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
+        self.logger = logger
         self.loaded_models: set = set()   # set of model keys loaded in GPU
         self.locked_models: Counter = Counter()   # set of model keys locked in GPU
 
+
     @contextlib.contextmanager
     def get_model(
         self,
@@ -112,6 +122,7 @@ class ModelCache(object):
         submodel: SDModelType=None,
         revision: str=None,
         legacy_info: LegacyInfo=None,
+        attach_model_part: Tuple[SDModelType, str] = (None,None),
         gpu_load: bool=True,
     )->Generator[ModelClass, None, None]:
         '''
@@ -122,10 +133,28 @@ class ModelCache(object):
         with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
             do_something_with_the_model(SD2)
 
+        You can fetch an individual part of a diffusers model by passing the submodel
+        argument:
+
+        vae_context = cache.get_model(
+            'stabilityai/stable-diffusion-2',
+            submodel=SDModelType.vae
+        )
+
+        Conversely, you can load and attach an external submodel to a diffusers model
+        before returning it by passing the attach_model_part argument. This only works with
+        diffusers models:
+
+        pipeline_context = cache.get_model(
+            'runwayml/stable-diffusion-v1-5',
+            attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse')
+        )
+
         The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
+        :param attach_model_part: load and attach a diffusers model component.
Pass a tuple of format (SDModelType,repo_id) :param revision: model revision :param model_class: class of model to return :param gpu_load: load the model into GPU [default True] @@ -151,6 +180,8 @@ class ModelCache(object): revision=revision, legacy_info=legacy_info, ) + if model_type==SDModelType.diffusion_pipeline and attach_model_part[0]: + self.attach_part(model,*attach_model_part) self.stack.append(key) # add to LRU cache self.models[key]=model # keep copy of model in dict @@ -163,7 +194,7 @@ class ModelCache(object): self.locked_models[key] += 1 if self.lazy_offloading: self._offload_unlocked_models() - logger.debug(f'Loading {key} into {self.execution_device}') + self.logger.debug(f'Loading {key} into {self.execution_device}') model.to(self.execution_device) # move into GPU self._print_cuda_stats() yield model @@ -181,25 +212,59 @@ class ModelCache(object): self.loaded_models.remove(key) yield model - def _offload_unlocked_models(self): - to_offload = set() - for key in self.loaded_models: - if key not in self.locked_models or self.locked_models[key] == 0: - logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}') - to_offload.add(key) - for key in to_offload: - self.models[key].to(self.storage_device) - self.loaded_models.remove(key) + def attach_part(self, + diffusers_model: StableDiffusionPipeline, + part_type: SDModelType, + part_id: str + ): + ''' + Attach a diffusers model part to a diffusers model. This can be + used to replace the VAE, tokenizer, textencoder, unet, etc. + :param diffuser_model: The diffusers model to attach the part to. + :param part_type: An SD ModelType indicating the part + :param part_id: A HF repo_id for the part + ''' + part_key = part_type.name + part_class = part_type.value + part = self._load_diffusers_from_storage( + part_id, + model_class=part_class, + ) + part.to(diffusers_model.device) + setattr(diffusers_model,part_key,part) + self.logger.debug(f'Attached {part_key} {part_id}') + + def status(self, + repo_id_or_path: Union[str,Path], + model_type: SDModelType=SDModelType.diffusion_pipeline, + revision: str=None, + subfolder: Path=None, + )->ModelStatus: + key = self._model_key( + repo_id_or_path, + model_type.value, + revision, + subfolder) + if key not in self.models: + return ModelStatus.not_loaded + if key in self.loaded_models: + if self.locked_models[key] > 0: + return ModelStatus.active + else: + return ModelStatus.in_vram + else: + return ModelStatus.in_ram def model_hash(self, repo_id_or_path: Union[str,Path], - revision: str=None)->str: + revision: str="main")->str: ''' Given the HF repo id or path to a model on disk, returns a unique hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs :param repo_id_or_path: repo_id string or Path to model file/directory on disk. :param revision: optional revision string (if fetching a HF repo_id) ''' + revision = revision or "main" if self.is_legacy_ckpt(repo_id_or_path): return self._legacy_model_hash(repo_id_or_path) elif Path(repo_id_or_path).is_dir(): @@ -211,96 +276,6 @@ class ModelCache(object): "Return the current number of models cached." 
return len(self.models) - @staticmethod - def _model_key(path,model_class,revision,subfolder)->str: - return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')]) - - def _has_cuda(self)->bool: - return self.execution_device.type == 'cuda' - - def _print_cuda_stats(self): - vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9) - loaded_models = len(self.loaded_models) - locked_models = len([x for x in self.locked_models if self.locked_models[x]>0]) - logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}") - - def _make_cache_room(self): - models_in_ram = len(self.models) - while models_in_ram >= self.max_models: - if least_recently_used_key := self.stack.pop(0): - logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}') - del self.models[least_recently_used_key] - models_in_ram = len(self.models) - gc.collect() - - @property - def current_model(self)->ModelClass: - ''' - Returns current model. - ''' - return self.models[self._current_model_key] - - @property - def _current_model_key(self)->str: - ''' - Returns key of currently loaded model. - ''' - return self.stack[-1] - - def _load_model_from_storage( - self, - repo_id_or_path: Union[str,Path], - subfolder: Path=None, - revision: str=None, - model_class: ModelClass=StableDiffusionGeneratorPipeline, - legacy_info: LegacyInfo=None, - )->ModelClass: - ''' - Load and return a HuggingFace model. - :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model - :param subfolder: name of a subfolder in which the model can be found, e.g. "vae" - :param revision: model revision - :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline - :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt - ''' - # silence transformer and diffuser warnings - with SilenceWarnings(): - if self.is_legacy_ckpt(repo_id_or_path): - model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info) - else: - model = self._load_diffusers_from_storage( - repo_id_or_path, - subfolder, - revision, - model_class, - ) - if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline): - model.enable_offload_submodels(self.execution_device) - elif hasattr(model,'to'): - model.to(self.execution_device) - return model - - def _load_diffusers_from_storage( - self, - repo_id_or_path: Union[str,Path], - subfolder: Path=None, - revision: str=None, - model_class: ModelClass=StableDiffusionGeneratorPipeline, - )->ModelClass: - ''' - Load and return a HuggingFace model using from_pretrained(). - :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model - :param subfolder: name of a subfolder in which the model can be found, e.g. 
"vae" - :param revision: model revision - :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline - ''' - return model_class.from_pretrained( - repo_id_or_path, - revision=revision, - subfolder=subfolder or '.', - cache_dir=global_cache_dir('hub'), - ) - @classmethod def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool: ''' @@ -327,6 +302,106 @@ class ModelCache(object): else: logger.debug("Model scanned ok") + @staticmethod + def _model_key(path,model_class,revision,subfolder)->str: + return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')]) + + def _has_cuda(self)->bool: + return self.execution_device.type == 'cuda' + + def _print_cuda_stats(self): + vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9) + loaded_models = len(self.loaded_models) + locked_models = len([x for x in self.locked_models if self.locked_models[x]>0]) + logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}") + + def _make_cache_room(self): + models_in_ram = len(self.models) + while models_in_ram >= self.max_models: + if least_recently_used_key := self.stack.pop(0): + logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}') + del self.models[least_recently_used_key] + models_in_ram = len(self.models) + gc.collect() + + def _offload_unlocked_models(self): + to_offload = set() + for key in self.loaded_models: + if key not in self.locked_models or self.locked_models[key] == 0: + self.logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}') + to_offload.add(key) + for key in to_offload: + self.models[key].to(self.storage_device) + self.loaded_models.remove(key) + + def _load_model_from_storage( + self, + repo_id_or_path: Union[str,Path], + subfolder: Path=None, + revision: str=None, + model_class: ModelClass=StableDiffusionGeneratorPipeline, + legacy_info: LegacyInfo=None, + )->ModelClass: + ''' + Load and return a HuggingFace model. + :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model + :param subfolder: name of a subfolder in which the model can be found, e.g. "vae" + :param revision: model revision + :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline + :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt + ''' + if self.is_legacy_ckpt(repo_id_or_path): + model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info) + else: + model = self._load_diffusers_from_storage( + repo_id_or_path, + subfolder, + revision, + model_class, + ) + if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline): + model.enable_offload_submodels(self.execution_device) + return model + + def _load_diffusers_from_storage( + self, + repo_id_or_path: Union[str,Path], + subfolder: Path=None, + revision: str=None, + model_class: ModelClass=StableDiffusionGeneratorPipeline, + )->ModelClass: + ''' + Load and return a HuggingFace model using from_pretrained(). + :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model + :param subfolder: name of a subfolder in which the model can be found, e.g. 
"vae" + :param revision: model revision + :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline + ''' + self.logger.info(f'Loading model {repo_id_or_path}') + revisions = [revision] if revision \ + else ['fp16','main'] if self.precision==torch.float16 \ + else ['main'] + extra_args = {'precision': self.precision} \ + if model_class in DiffusionClasses \ + else {} + + # silence transformer and diffuser warnings + with SilenceWarnings(): + for rev in revisions: + try: + model = model_class.from_pretrained( + repo_id_or_path, + revision=rev, + subfolder=subfolder or '.', + cache_dir=global_cache_dir('hub'), + **extra_args, + ) + self.logger.debug(f'Found revision {rev}') + break + except OSError: + pass + return model + def _load_ckpt_from_storage(self, ckpt_path: Union[str,Path], legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline: @@ -336,13 +411,17 @@ class ModelCache(object): :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required ''' assert legacy_info is not None - pipeline = load_pipeline_from_original_stable_diffusion_ckpt( - checkpoint_path=ckpt_path, - original_config_file=legacy_info.config_file, - vae_path=legacy_info.vae_file, - return_generator_pipeline=True, - precision=self.precision, - ) + + # deferred loading to avoid circular import errors + from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt + with SilenceWarnings(): + pipeline = load_pipeline_from_original_stable_diffusion_ckpt( + checkpoint_path=ckpt_path, + original_config_file=legacy_info.config_file, + vae_path=legacy_info.vae_file, + return_generator_pipeline=True, + precision=self.precision, + ) return pipeline def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str: @@ -399,6 +478,7 @@ class ModelCache(object): raise KeyError(f"Revision '{revision}' not found in {repo_id}") return desired_revisions[0].target_commit + class SilenceWarnings(object): def __init__(self): self.transformers_verbosity = transformers_logging.get_verbosity() diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 9ba1e8779c..3977ac0ed7 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -1,56 +1,88 @@ -""" -Manage a cache of Stable Diffusion model files for fast switching. -They are moved between GPU and CPU as necessary. If CPU memory falls -below a preset minimum, the least recently used model will be -cleared and loaded from disk when next needed. +"""This module manages the InvokeAI `models.yaml` file, mapping +symbolic diffusers model names to the paths and repo_ids used +by the underlying `from_pretrained()` call. + +For fetching models, use manager.get_model('symbolic name'). This will +return a SDModelInfo object that contains the following attributes: + + * context -- a context manager Generator that loads and locks the + model into GPU VRAM and returns the model for use. + See below for usage. + * name -- symbolic name of the model + * hash -- unique hash for the model + * location -- path or repo_id of the model + * revision -- revision of the model if coming from a repo id, + e.g. 
'fp16' + * precision -- torch precision of the model + * status -- a ModelStatus enum corresponding to one of + 'not_loaded', 'in_ram', 'in_vram' or 'active' + +Typical usage: + + from invokeai.backend import ModelManager + manager = ModelManager(config_path='./configs/models.yaml',max_models=4) + model_info = manager.get_model('stable-diffusion-1.5') + with model_info.context as my_model: + my_model.latents_from_embeddings(...) + +The manager uses the underlying ModelCache class to keep +frequently-used models in RAM and move them into GPU as needed for +generation operations. The ModelCache object can be accessed using +the manager's "cache" attribute. + +Other methods provided by ModelManager support importing, editing, +converting and deleting models. """ from __future__ import annotations -import contextlib -import gc -import hashlib import os import re -import sys import textwrap -import time -import warnings +from dataclasses import dataclass from enum import Enum, auto from pathlib import Path -from shutil import move, rmtree -from typing import Any, Optional, Union, Callable, types +from shutil import rmtree +from typing import Union, Callable, types import safetensors import safetensors.torch import torch -import transformers import invokeai.backend.util.logging as logger -from diffusers import ( - AutoencoderKL, - UNet2DConditionModel, - SchedulerMixin, - logging as dlogging, -) from huggingface_hub import scan_cache_dir from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from picklescan.scanner import scan_file_path -from invokeai.backend.globals import Globals, global_cache_dir +from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path +from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo -from transformers import ( - CLIPTextModel, - CLIPTokenizer, - CLIPFeatureExtractor, -) -from diffusers.pipelines.stable_diffusion.safety_checker import ( - StableDiffusionSafetyChecker, - ) -from ..stable_diffusion import ( - StableDiffusionGeneratorPipeline, -) -from ..util import CUDA_DEVICE, ask_user, download_with_resume +from ..util import CUDA_DEVICE +# wanted to use pydantic here, but Generator objects not supported +@dataclass +class SDModelInfo(): + context: Generator[ModelClass, None, None] + name: str + hash: str + location: Union[Path,str] + precision: torch.dtype + subfolder: Path = None + revision: str = None + _cache: ModelCache = None + + + def status(self)->ModelStatus: + '''Return load status of this model as a model_cache.ModelStatus enum''' + if not self._cache: + return ModelStatus.unknown + return self._cache.status( + self.location, + revision = self.revision, + subfolder = self.subfolder + ) + +class InvalidModelError(Exception): + "Raised when an invalid model is requested" + pass class SDLegacyType(Enum): V1 = auto() @@ -60,54 +92,39 @@ class SDLegacyType(Enum): V2_v = auto() UNKNOWN = auto() -class SDModelComponent(Enum): - vae="vae" - text_encoder="text_encoder" - tokenizer="tokenizer" - unet="unet" - scheduler="scheduler" - safety_checker="safety_checker" - feature_extractor="feature_extractor" - DEFAULT_MAX_MODELS = 2 class ModelManager(object): """ - Model manager handles loading, caching, importing, deleting, converting, and editing models. + High-level interface to model management. 
""" logger: types.ModuleType = logger def __init__( - self, - config: OmegaConf | Path, - device_type: torch.device = CUDA_DEVICE, - precision: str = "float16", - max_loaded_models=DEFAULT_MAX_MODELS, - sequential_offload=False, - embedding_path: Path = None, - logger: types.ModuleType = logger, + self, + config_path: Path, + device_type: torch.device = CUDA_DEVICE, + precision: torch.dtype = torch.float16, + max_models=DEFAULT_MAX_MODELS, + sequential_offload=False, + logger: types.ModuleType = logger, ): """ - Initialize with the path to the models.yaml config file or - an initialized OmegaConf dictionary. Optional parameters - are the torch device type, precision, max_loaded_models, + Initialize with the path to the models.yaml config file. + Optional parameters are the torch device type, precision, max_models, and sequential_offload boolean. Note that the default device type and precision are set up for a CUDA system running at half precision. """ - # prevent nasty-looking CLIP log message - transformers.logging.set_verbosity_error() - if not isinstance(config, DictConfig): - config = OmegaConf.load(config) - self.config = config - self.precision = precision - self.device = torch.device(device_type) - self.max_loaded_models = max_loaded_models - self.models = {} - self.stack = [] # this is an LRU FIFO - self.current_model = None - self.sequential_offload = sequential_offload - self.embedding_path = embedding_path + self.config_path = config_path + self.config = OmegaConf.load(self.config_path) + self.cache = ModelCache( + max_models=max_models, + execution_device = device_type, + precision = precision, + sequential_offload = sequential_offload, + logger = logger, + ) self.logger = logger def valid_model(self, model_name: str) -> bool: @@ -117,138 +134,69 @@ class ModelManager(object): """ return model_name in self.config - def get_model(self, model_name: str = None) -> dict: - """Given a model named identified in models.yaml, return a dict - containing the model object and some of its key features. If - in RAM will load into GPU VRAM. If on disk, will load from - there. - The dict has the following keys: - 'model': The StableDiffusionGeneratorPipeline object - 'model_name': The name of the model in models.yaml - 'width': The width of images trained by this model - 'height': The height of images trained by this model - 'hash': A unique hash of this model's files on disk. + def get_model(self, + model_name: str = None, + submodel: SDModelType=None, + ) -> SDModelInfo: + """Given a model named identified in models.yaml, return + an SDModelInfo object describing it. + :param model_name: symbolic name of the model in models.yaml + :param submodel: an SDModelType enum indicating the portion of + the model to retrieve (e.g. SDModelType.vae) """ if not model_name: - return ( - self.get_model(self.current_model) - if self.current_model - else self.get_model(self.default_model()) - ) + model_name = self.default_model() if not self.valid_model(model_name): - self.logger.error( + raise InvalidModelError( f'"{model_name}" is not a known model name. 
Please check your models.yaml file' ) - return self.current_model - if self.current_model != model_name: - if model_name not in self.models: # make room for a new one - self._make_cache_room() - self.offload_model(self.current_model) - - if model_name in self.models: - requested_model = self.models[model_name]["model"] - self.logger.info(f"Retrieving model {model_name} from system RAM cache") - requested_model.ready() - width = self.models[model_name]["width"] - height = self.models[model_name]["height"] - hash = self.models[model_name]["hash"] - - else: # we're about to load a new model, so potentially offload the least recently used one - requested_model, width, height, hash = self._load_model(model_name) - self.models[model_name] = { - "model_name": model_name, - "model": requested_model, - "width": width, - "height": height, - "hash": hash, - } - - self.current_model = model_name - self._push_newest_model(model_name) - return { - "model_name": model_name, - "model": requested_model, - "width": width, - "height": height, - "hash": hash, - } - - def get_model_vae(self, model_name: str=None)->AutoencoderKL: - """Given a model name identified in models.yaml, load the model into - GPU if necessary and return its assigned VAE as an - AutoencoderKL object. If no model name is provided, return the - vae from the model currently in the GPU. - """ - return self._get_sub_model(model_name, SDModelComponent.vae) - - def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer: - """Given a model name identified in models.yaml, load the model into - GPU if necessary and return its assigned CLIPTokenizer. If no - model name is provided, return the tokenizer from the model - currently in the GPU. - """ - return self._get_sub_model(model_name, SDModelComponent.tokenizer) - - def get_model_unet(self, model_name: str=None)->UNet2DConditionModel: - """Given a model name identified in models.yaml, load the model into - GPU if necessary and return its assigned UNet2DConditionModel. If no model - name is provided, return the UNet from the model - currently in the GPU. - """ - return self._get_sub_model(model_name, SDModelComponent.unet) - - def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel: - """Given a model name identified in models.yaml, load the model into - GPU if necessary and return its assigned CLIPTextModel. If no - model name is provided, return the text encoder from the model - currently in the GPU. - """ - return self._get_sub_model(model_name, SDModelComponent.text_encoder) - - def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor: - """Given a model name identified in models.yaml, load the model into - GPU if necessary and return its assigned CLIPFeatureExtractor. If no - model name is provided, return the text encoder from the model - currently in the GPU. - """ - return self._get_sub_model(model_name, SDModelComponent.feature_extractor) - - def get_model_scheduler(self, model_name: str=None)->SchedulerMixin: - """Given a model name identified in models.yaml, load the model into - GPU if necessary and return its assigned scheduler. If no - model name is provided, return the text encoder from the model - currently in the GPU. 
- """ - return self._get_sub_model(model_name, SDModelComponent.scheduler) - - def _get_sub_model( - self, - model_name: str=None, - model_part: SDModelComponent=SDModelComponent.vae, - ) -> Union[ - AutoencoderKL, - CLIPTokenizer, - CLIPFeatureExtractor, - UNet2DConditionModel, - CLIPTextModel, - StableDiffusionSafetyChecker, - ]: - """Given a model name identified in models.yaml, and the part of the - model you wish to retrieve, return that part. Parts are in an Enum - class named SDModelComponent, and consist of: - SDModelComponent.vae - SDModelComponent.text_encoder - SDModelComponent.tokenizer - SDModelComponent.unet - SDModelComponent.scheduler - SDModelComponent.safety_checker - SDModelComponent.feature_extractor - """ - model_dict = self.get_model(model_name) - model = model_dict["model"] - return getattr(model, model_part.value) + # get the required loading info out of the config file + mconfig = self.config[model_name] + format = mconfig.get('format','diffusers') + legacy = None + if format=='ckpt': + location = global_resolve_path(mconfig.weights) + legacy = LegacyInfo( + config_file = global_resolve_path(mconfig.config), + ) + if mconfig.get('vae'): + legacy.vae_file = global_resolve_path(mconfig.vae) + elif format=='diffusers': + location = mconfig.repo_id + revision = mconfig.get('revision') + else: + raise InvalidModelError( + f'"{model_name}" has an unknown format {format}' + ) + + subfolder = mconfig.get('subfolder') + hash = self.cache.model_hash(location,revision) + vae = (None,None) + try: + vae_id = mconfig.vae.repo_id + vae = (SDModelType.vae,vae_id) + except Exception: + pass + model_context = self.cache.get_model( + location, + revision = revision, + subfolder = subfolder, + legacy_info = legacy, + submodel = submodel, + attach_model_part=vae, + ) + return SDModelInfo( + context = model_context, + name = model_name, + hash = hash, + location = location, + revision = revision, + precision = self.cache.precision, + subfolder = subfolder, + _cache = self.cache + ) def default_model(self) -> str | None: """ @@ -324,19 +272,19 @@ class ModelManager(object): format = stanza.get("format", "ckpt") # Determine Format # Common Attribs + status = self.cache.status( + stanza.get('weights') or stanza.get('repo_id'), + revision=stanza.get('revision'), + subfolder=stanza.get('subfolder') + ) description = stanza.get("description", None) - if self.current_model == name: - status = "active" - elif name in self.models: - status = "cached" - else: - status = "not loaded" models[name].update( description=description, format=format, - status=status, + status=status.value ) + # Checkpoint Config Parse if format == "ckpt": models[name].update( @@ -373,7 +321,7 @@ class ModelManager(object): for name in models: if models[name]["format"] == "vae": continue - line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["format"]:10s} {models[name]["description"]}' + line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}' if models[name]["status"] == "active": line = f"\033[1m{line}\033[0m" print(line) @@ -441,233 +389,6 @@ class ModelManager(object): if clobber: self._invalidate_cached_model(model_name) - def _load_model(self, model_name: str): - """Load and initialize the model from configuration variables passed at object creation time""" - if model_name not in self.config: - self.logger.error( - f'"{model_name}" is not a known model name. 
Please check your models.yaml file' - ) - return - - mconfig = self.config[model_name] - - # for usage statistics - if self._has_cuda(): - torch.cuda.reset_peak_memory_stats() - torch.cuda.empty_cache() - - tic = time.time() - - # this does the work - model_format = mconfig.get("format", "ckpt") - if model_format == "ckpt": - weights = mconfig.weights - self.logger.info(f"Loading {model_name} from {weights}") - model, width, height, model_hash = self._load_ckpt_model( - model_name, mconfig - ) - elif model_format == "diffusers": - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - model, width, height, model_hash = self._load_diffusers_model(mconfig) - else: - raise NotImplementedError( - f"Unknown model format {model_name}: {model_format}" - ) - self._add_embeddings_to_model(model) - - # usage statistics - toc = time.time() - self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic)) - if self._has_cuda(): - self.logger.info( - "Max VRAM used to load the model: "+ - "%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9) - ) - self.logger.info( - "Current VRAM usage: "+ - "%4.2fG" % (torch.cuda.memory_allocated() / 1e9) - ) - return model, width, height, model_hash - - def _load_diffusers_model(self, mconfig): - name_or_path = self.model_name_or_path(mconfig) - using_fp16 = self.precision == "float16" - - self.logger.info(f"Loading diffusers model from {name_or_path}") - if using_fp16: - self.logger.debug("Using faster float16 precision") - else: - self.logger.debug("Using more accurate float32 precision") - - # TODO: scan weights maybe? - pipeline_args: dict[str, Any] = dict( - safety_checker=None, local_files_only=not Globals.internet_available - ) - if "vae" in mconfig and mconfig["vae"] is not None: - if vae := self._load_vae(mconfig["vae"]): - pipeline_args.update(vae=vae) - if not isinstance(name_or_path, Path): - pipeline_args.update(cache_dir=global_cache_dir("hub")) - if using_fp16: - pipeline_args.update(torch_dtype=torch.float16) - fp_args_list = [{"revision": "fp16"}, {}] - else: - fp_args_list = [{}] - - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - - pipeline = None - for fp_args in fp_args_list: - try: - pipeline = StableDiffusionGeneratorPipeline.from_pretrained( - name_or_path, - **pipeline_args, - **fp_args, - ) - except OSError as e: - if str(e).startswith("fp16 is not a valid"): - pass - else: - self.logger.error( - f"An unexpected error occurred while downloading the model: {e})" - ) - if pipeline: - break - - dlogging.set_verbosity(verbosity) - assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded') - - if self.sequential_offload: - pipeline.enable_offload_submodels(self.device) - else: - pipeline.to(self.device) - - model_hash = self._diffuser_sha256(name_or_path) - - # square images??? - width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor - height = width - self.logger.debug(f"Default image dimensions = {width} x {height}") - - return pipeline, width, height, model_hash - - def _load_ckpt_model(self, model_name, mconfig): - config = mconfig.config - weights = mconfig.weights - vae = mconfig.get("vae") - width = mconfig.width - height = mconfig.height - - if not os.path.isabs(config): - config = os.path.join(Globals.root, config) - if not os.path.isabs(weights): - weights = os.path.normpath(os.path.join(Globals.root, weights)) - - # Convert to diffusers and return a diffusers pipeline - self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...") - - from . 
import load_pipeline_from_original_stable_diffusion_ckpt - - try: - if self.list_models()[self.current_model]["status"] == "active": - self.offload_model(self.current_model) - except Exception: - pass - - vae_path = None - if vae: - vae_path = ( - vae - if os.path.isabs(vae) - else os.path.normpath(os.path.join(Globals.root, vae)) - ) - if self._has_cuda(): - torch.cuda.empty_cache() - pipeline = load_pipeline_from_original_stable_diffusion_ckpt( - checkpoint_path=weights, - original_config_file=config, - vae_path=vae_path, - return_generator_pipeline=True, - precision=torch.float16 if self.precision == "float16" else torch.float32, - ) - if self.sequential_offload: - pipeline.enable_offload_submodels(self.device) - else: - pipeline.to(self.device) - return ( - pipeline, - width, - height, - "NOHASH", - ) - - def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path: - if isinstance(model_name, DictConfig) or isinstance(model_name, dict): - mconfig = model_name - elif model_name in self.config: - mconfig = self.config[model_name] - else: - raise ValueError( - f'"{model_name}" is not a known model name. Please check your models.yaml file' - ) - - if "path" in mconfig and mconfig["path"] is not None: - path = Path(mconfig["path"]) - if not path.is_absolute(): - path = Path(Globals.root, path).resolve() - return path - elif "repo_id" in mconfig: - return mconfig["repo_id"] - else: - raise ValueError("Model config must specify either repo_id or path.") - - def offload_model(self, model_name: str) -> None: - """ - Offload the indicated model to CPU. Will call - _make_cache_room() to free space if needed. - """ - if model_name not in self.models: - return - - self.logger.info(f"Offloading {model_name} to CPU") - model = self.models[model_name]["model"] - model.offload_all() - self.current_model = None - - gc.collect() - if self._has_cuda(): - torch.cuda.empty_cache() - - @classmethod - def scan_model(self, model_name, checkpoint): - """ - Apply picklescanner to the indicated checkpoint and issue a warning - and option to exit if an infected file is identified. - """ - # scan model - self.logger.debug(f"Scanning Model: {model_name}") - scan_result = scan_file_path(checkpoint) - if scan_result.infected_files != 0: - if scan_result.infected_files == 1: - self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}") - self.logger.critical("The model you are trying to load seems to be infected.") - self.logger.critical("For your safety, InvokeAI will not load this model.") - self.logger.critical("Please use checkpoints from trusted sources.") - self.logger.critical("Exiting InvokeAI") - sys.exit() - else: - self.logger.warning("InvokeAI was unable to scan the model you are using.") - model_safe_check_fail = ask_user( - "Do you want to to continue loading the model?", ["y", "n"] - ) - if model_safe_check_fail.lower() != "y": - self.logger.critical("Exiting InvokeAI") - sys.exit() - else: - self.logger.debug("Model scanned ok") - def import_diffuser_model( self, repo_or_path: Union[str, Path], @@ -949,8 +670,6 @@ class ModelManager(object): new_config = None - from . import convert_ckpt_to_diffusers - if diffusers_path.exists(): self.logger.error( f"The path {str(diffusers_path)} already exists. Please move or remove it and try again." 
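Editor's note on the deferred imports in this patch: both model_cache.py ("# deferred loading to avoid circular import errors") and the convert_and_import hunk that follows move their import of convert_ckpt_to_diffusers from module scope into the function that needs it, so the two modules can import one another's public names without a cycle at load time. A minimal two-file sketch of the pattern, with hypothetical module names cache_demo.py and converter_demo.py standing in for model_cache.py and convert_ckpt_to_diffusers.py:

# cache_demo.py -- stands in for model_cache.py
def load_ckpt(path: str):
    # Importing here, at call time rather than at module scope, means that
    # importing cache_demo never triggers an import of converter_demo.
    from converter_demo import convert
    return convert(path)

# converter_demo.py -- stands in for convert_ckpt_to_diffusers.py
from cache_demo import load_ckpt   # safe: cache_demo has no module-level import of converter_demo

def convert(path: str) -> str:
    return f"diffusers pipeline converted from {path}"

Either module can now be imported first, which is what lets convert_ckpt_to_diffusers.py use ModelCache.scan_model while ModelCache still calls back into the converter for legacy checkpoints.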
@@ -960,6 +679,10 @@ class ModelManager(object): model_name = model_name or diffusers_path.name model_description = model_description or f"Converted version of {model_name}" self.logger.debug(f"Converting {model_name} to diffusers (30-60s)") + + # to avoid circular import errors + from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers + try: # By passing the specified VAE to the conversion function, the autoencoder # will be built into the model rather than tacked on afterward via the config file @@ -1020,33 +743,12 @@ class ModelManager(object): return search_folder, found_models - def _make_cache_room(self) -> None: - num_loaded_models = len(self.models) - if num_loaded_models >= self.max_loaded_models: - least_recent_model = self._pop_oldest_model() - self.logger.info( - f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}" - ) - if least_recent_model is not None: - del self.models[least_recent_model] - gc.collect() - - def print_vram_usage(self) -> None: - if self._has_cuda: - self.logger.info( - "Current VRAM usage:"+ - "%4.2fG" % (torch.cuda.memory_allocated() / 1e9), - ) - - def commit(self, config_file_path: str) -> None: + def commit(self) -> None: """ Write current configuration out to the indicated file. """ yaml_str = OmegaConf.to_yaml(self.config) - if not os.path.isabs(config_file_path): - config_file_path = os.path.normpath( - os.path.join(Globals.root, config_file_path) - ) + config_file_path = self.config_path tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp") with open(tmpfile, "w", encoding="utf-8") as outfile: outfile.write(self.preamble()) @@ -1069,240 +771,6 @@ class ModelManager(object): """ ) - @classmethod - def migrate_models(cls): - """ - Migrate the ~/invokeai/models directory from the legacy format used through 2.2.5 - to the 2.3.0 "diffusers" version. This should be a one-time operation, called at - script startup time. - """ - # Three transformer models to check: bert, clip and safety checker, and - # the diffusers as well - models_dir = Path(Globals.root, "models") - legacy_locations = [ - Path( - models_dir, - "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker", - ), - Path(models_dir, "bert-base-uncased/models--bert-base-uncased"), - Path( - models_dir, - "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14", - ), - ] - legacy_locations.extend(list(global_cache_dir("diffusers").glob("*"))) - - legacy_layout = False - for model in legacy_locations: - legacy_layout = legacy_layout or model.exists() - if not legacy_layout: - return - - print( - """ ->> ALERT: ->> The location of your previously-installed diffusers models needs to move from ->> invokeai/models/diffusers to invokeai/models/hub due to a change introduced by ->> diffusers version 0.14. InvokeAI will now move all models from the "diffusers" directory ->> into "hub" and then remove the diffusers directory. This is a quick, safe, one-time ->> operation.""" - ) - - # transformer files get moved into the hub directory - if cls._is_huggingface_hub_directory_present(): - hub = global_cache_dir("hub") - else: - hub = models_dir / "hub" - - os.makedirs(hub, exist_ok=True) - for model in legacy_locations: - source = models_dir / model - dest = hub / model.stem - if dest.exists() and not source.exists(): - continue - cls.logger.info(f"{source} => {dest}") - if source.exists(): - if dest.is_symlink(): - logger.warning(f"Found symlink at {dest.name}. 
Not migrating.") - elif dest.exists(): - if source.is_dir(): - rmtree(source) - else: - source.unlink() - else: - move(source, dest) - - # now clean up by removing any empty directories - empty = [ - root - for root, dirs, files, in os.walk(models_dir) - if not len(dirs) and not len(files) - ] - for d in empty: - os.rmdir(d) - cls.logger.info("Migration is done. Continuing...") - - def _resolve_path( - self, source: Union[str, Path], dest_directory: str - ) -> Optional[Path]: - resolved_path = None - if str(source).startswith(("http:", "https:", "ftp:")): - dest_directory = Path(dest_directory) - if not dest_directory.is_absolute(): - dest_directory = Globals.root / dest_directory - dest_directory.mkdir(parents=True, exist_ok=True) - resolved_path = download_with_resume(str(source), dest_directory) - else: - if not os.path.isabs(source): - source = os.path.join(Globals.root, source) - resolved_path = Path(source) - return resolved_path - - def _invalidate_cached_model(self, model_name: str) -> None: - self.offload_model(model_name) - if model_name in self.stack: - self.stack.remove(model_name) - self.models.pop(model_name, None) - - def _pop_oldest_model(self): - """ - Remove the first element of the FIFO, which ought - to be the least recently accessed model. Do not - pop the last one, because it is in active use! - """ - return self.stack.pop(0) - - def _push_newest_model(self, model_name: str) -> None: - """ - Maintain a simple FIFO. First element is always the - least recent, and last element is always the most recent. - """ - with contextlib.suppress(ValueError): - self.stack.remove(model_name) - self.stack.append(model_name) - - def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline): - if self.embedding_path is not None: - self.logger.info(f"Loading embeddings from {self.embedding_path}") - for root, _, files in os.walk(self.embedding_path): - for name in files: - ti_path = os.path.join(root, name) - model.textual_inversion_manager.load_textual_inversion( - ti_path, defer_injecting_tokens=True - ) - self.logger.info( - f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}' - ) - - def _has_cuda(self) -> bool: - return self.device.type == "cuda" - - def _diffuser_sha256( - self, name_or_path: Union[str, Path], chunksize=16777216 - ) -> Union[str, bytes]: - path = None - if isinstance(name_or_path, Path): - path = name_or_path - else: - owner, repo = name_or_path.split("/") - path = Path(global_cache_dir("hub") / f"models--{owner}--{repo}") - if not path.exists(): - return None - hashpath = path / "checksum.sha256" - if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime: - with open(hashpath) as f: - hash = f.read() - return hash - self.logger.debug("Calculating sha256 hash of model files") - tic = time.time() - sha = hashlib.sha256() - count = 0 - for root, dirs, files in os.walk(path, followlinks=False): - for name in files: - count += 1 - with open(os.path.join(root, name), "rb") as f: - while chunk := f.read(chunksize): - sha.update(chunk) - hash = sha.hexdigest() - toc = time.time() - self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic)) - with open(hashpath, "w") as f: - f.write(hash) - return hash - - def _cached_sha256(self, path, data) -> Union[str, bytes]: - dirname = os.path.dirname(path) - basename = os.path.basename(path) - base, _ = os.path.splitext(basename) - hashpath = os.path.join(dirname, base + ".sha256") - - if os.path.exists(hashpath) 
and os.path.getmtime(path) <= os.path.getmtime( - hashpath - ): - with open(hashpath) as f: - hash = f.read() - return hash - - self.logger.debug("Calculating sha256 hash of weights file") - tic = time.time() - sha = hashlib.sha256() - sha.update(data) - hash = sha.hexdigest() - toc = time.time() - self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic)) - - with open(hashpath, "w") as f: - f.write(hash) - return hash - - def _load_vae(self, vae_config) -> AutoencoderKL: - vae_args = {} - try: - name_or_path = self.model_name_or_path(vae_config) - except Exception: - return None - if name_or_path is None: - return None - using_fp16 = self.precision == "float16" - - vae_args.update( - cache_dir=global_cache_dir("hub"), - local_files_only=not Globals.internet_available, - ) - - self.logger.debug(f"Loading diffusers VAE from {name_or_path}") - if using_fp16: - vae_args.update(torch_dtype=torch.float16) - fp_args_list = [{"revision": "fp16"}, {}] - else: - self.logger.debug("Using more accurate float32 precision") - fp_args_list = [{}] - - vae = None - deferred_error = None - - # A VAE may be in a subfolder of a model's repository. - if "subfolder" in vae_config: - vae_args["subfolder"] = vae_config["subfolder"] - - for fp_args in fp_args_list: - # At some point we might need to be able to use different classes here? But for now I think - # all Stable Diffusion VAE are AutoencoderKL. - try: - vae = AutoencoderKL.from_pretrained(name_or_path, **vae_args, **fp_args) - except OSError as e: - if str(e).startswith("fp16 is not a valid"): - pass - else: - deferred_error = e - if vae: - break - - if not vae and deferred_error: - self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}") - - return vae - @classmethod def _delete_model_from_cache(cls,repo_id): cache_info = scan_cache_dir(global_cache_dir("hub")) @@ -1326,8 +794,3 @@ class ModelManager(object): return path return Path(Globals.root, path).resolve() - @staticmethod - def _is_huggingface_hub_directory_present() -> bool: - return ( - os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None - )
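Usage note: the high-level flow described in the new model_manager.py module docstring looks roughly like the sketch below. It is assembled from the docstrings and signatures added in this patch; the config path and the symbolic model name 'stable-diffusion-1.5' are illustrative assumptions taken from the docstring example, and the bodies of the with-blocks are placeholders.

import torch
from invokeai.backend.model_management import ModelManager, ModelStatus
from invokeai.backend.model_management.model_cache import SDModelType

# Symbolic model names are resolved through models.yaml.
manager = ModelManager(
    config_path='./configs/models.yaml',
    precision=torch.float16,
    max_models=4,
)

# get_model() now returns an SDModelInfo record instead of a dict.
info = manager.get_model('stable-diffusion-1.5')
print(info.name, info.hash, info.status())       # e.g. ModelStatus.in_ram

# The context manager locks the pipeline into VRAM only for this block.
with info.context as pipeline:
    print(type(pipeline).__name__)                # StableDiffusionGeneratorPipeline

# A single component can be requested with the submodel argument.
vae_info = manager.get_model('stable-diffusion-1.5', submodel=SDModelType.vae)
with vae_info.context as vae:
    print(vae.__class__.__name__)                 # e.g. AutoencoderKL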
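The lower-level ModelCache can also be driven directly, as in the get_model docstring above. A sketch under the same assumptions (repo ids copied from the docstring examples, default devices, fp16 precision); whether status() reports 'active' inside the block depends on how the cache keys the entry, so the checks are shown as prints rather than asserts.

import torch
from invokeai.backend.model_management import ModelCache, ModelStatus
from invokeai.backend.model_management.model_cache import SDModelType

cache = ModelCache(max_models=2, precision=torch.float16)

# Load a pipeline and swap in an external VAE before it is returned.
with cache.get_model(
    'runwayml/stable-diffusion-v1-5',
    attach_model_part=(SDModelType.vae, 'stabilityai/sd-vae-ft-mse'),
) as pipeline:
    # Typically ModelStatus.active (locked in VRAM) while inside the block.
    print(cache.status('runwayml/stable-diffusion-v1-5').value)

# After the context exits the model stays cached but is no longer locked.
print(cache.status('runwayml/stable-diffusion-v1-5').value)
print(cache.cache_size())                         # number of models held in the cache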