mirror of https://github.com/invoke-ai/InvokeAI · synced 2024-08-30 20:32:17 +00:00
model manager rewritten to use model_cache; API changed!
This commit is contained in: parent a4e36bc02a · commit af8c7c7d29
@@ -10,7 +10,7 @@ from .generator import (
    Img2Img,
    Inpaint
)
from .model_management import ModelManager, SDModelComponent
from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .args import Args
from .globals import Globals
@@ -93,6 +93,8 @@ def global_converted_ckpts_dir() -> Path:
def global_set_root(root_dir: Union[str, Path]):
    Globals.root = root_dir

def global_resolve_path(path: Union[str,Path]):
    return Path(Globals.root,path).resolve()

def global_cache_dir(subdir: Union[str, Path] = "") -> Path:
    """
@@ -1,11 +1,5 @@
"""
Initialization file for invokeai.backend.model_management
"""
from .convert_ckpt_to_diffusers import (
    convert_ckpt_to_diffusers,
    load_pipeline_from_original_stable_diffusion_ckpt,
)
from .model_manager import ModelManager,SDModelComponent
from .model_manager import ModelManager
from .model_cache import ModelCache, ModelStatus
@@ -29,6 +29,7 @@ import invokeai.backend.util.logging as logger
from invokeai.backend.globals import global_cache_dir, global_config_dir

from .model_manager import ModelManager, SDLegacyType
from .model_cache import ModelCache

try:
    from omegaconf import OmegaConf
@@ -1100,7 +1101,7 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

    if Path(checkpoint_path).suffix == '.ckpt':
        if scan_needed:
            ModelManager.scan_model(checkpoint_path,checkpoint_path)
            ModelCache.scan_model(checkpoint_path,checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = load_file(checkpoint_path)
@@ -24,10 +24,10 @@ from collections.abc import Generator
from collections import Counter
from enum import Enum
from pathlib import Path
from typing import Sequence, Union
from typing import Sequence, Union, Tuple, types

import torch
from diffusers import AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import StableDiffusionPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel
from diffusers import logging as diffusers_logging
from diffusers.pipelines.stable_diffusion.safety_checker import \
    StableDiffusionSafetyChecker
@@ -40,7 +40,6 @@ from transformers import logging as transformers_logging
import invokeai.backend.util.logging as logger
from ..globals import global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from . import load_pipeline_from_original_stable_diffusion_ckpt

MAX_MODELS = 4
@@ -55,8 +54,16 @@ class SDModelType(Enum):
    safety_checker=StableDiffusionSafetyChecker
    feature_extractor=CLIPFeatureExtractor

class ModelStatus(Enum):
    unknown='unknown'
    not_loaded='not loaded'
    in_ram='cached'
    in_vram='in gpu'
    active='locked in gpu'

# The list of model classes we know how to fetch, for typechecking
ModelClass = Union[tuple([x.value for x in SDModelType])]
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)

# Legacy information needed to load a legacy checkpoint file
class LegacyInfo(BaseModel):
@@ -81,6 +88,7 @@ class ModelCache(object):
        sequential_offload: bool=False,
        lazy_offloading: bool=True,
        sha_chunksize: int = 16777216,
        logger: types.ModuleType = logger
    ):
        '''
        :param max_models: Maximum number of models to cache in CPU RAM [4]
@@ -100,9 +108,11 @@ class ModelCache(object):
        self.execution_device: torch.device=execution_device
        self.storage_device: torch.device=storage_device
        self.sha_chunksize=sha_chunksize
        self.logger = logger
        self.loaded_models: set = set()   # set of model keys loaded in GPU
        self.locked_models: Counter = Counter()   # set of model keys locked in GPU

    @contextlib.contextmanager
    def get_model(
        self,
@@ -112,6 +122,7 @@ class ModelCache(object):
        submodel: SDModelType=None,
        revision: str=None,
        legacy_info: LegacyInfo=None,
        attach_model_part: Tuple[SDModelType, str] = (None,None),
        gpu_load: bool=True,
    )->Generator[ModelClass, None, None]:
        '''
@@ -122,10 +133,28 @@ class ModelCache(object):
           with cache.get_model('stabilityai/stable-diffusion-2') as SD2:
               do_something_with_the_model(SD2)

        You can fetch an individual part of a diffusers model by passing the submodel
        argument:

           vae_context = cache.get_model(
               'stabilityai/sd-stable-diffusion-2',
               submodel=SDModelType.vae
           )

        Conversely, you can load an external submodel and attach it to a diffusers model
        before it is returned by passing the attach_model_part argument. This only works
        with diffusers models:

           pipeline_context = cache.get_model(
               'runwayml/stable-diffusion-v1-5',
               attach_model_part=(SDModelType.vae,'stabilityai/sd-vae-ft-mse')
           )

        The model will be locked into GPU VRAM for the duration of the context.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
        :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
        :param revision: model revision
        :param model_class: class of model to return
        :param gpu_load: load the model into GPU [default True]
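A minimal usage sketch of the context-manager API documented above. The cache construction arguments and the helper functions are illustrative assumptions, not part of this commit:

    import torch
    from invokeai.backend.model_management.model_cache import ModelCache, SDModelType

    cache = ModelCache(max_models=4, execution_device=torch.device('cuda'))  # assumed construction

    # Lock the whole diffusers pipeline into VRAM only for the duration of the block.
    with cache.get_model('stabilityai/stable-diffusion-2') as pipeline:
        do_something_with_the_model(pipeline)        # hypothetical helper

    # Fetch just the VAE submodel of the same diffusers model.
    with cache.get_model('stabilityai/stable-diffusion-2', submodel=SDModelType.vae) as vae:
        do_something_with_the_vae(vae)               # hypothetical helper

    # Attach an external VAE to the pipeline before it is returned.
    with cache.get_model(
        'runwayml/stable-diffusion-v1-5',
        attach_model_part=(SDModelType.vae, 'stabilityai/sd-vae-ft-mse'),
    ) as pipeline:
        do_something_with_the_model(pipeline)        # hypothetical helper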
@@ -151,6 +180,8 @@ class ModelCache(object):
                revision=revision,
                legacy_info=legacy_info,
            )
            if model_type==SDModelType.diffusion_pipeline and attach_model_part[0]:
                self.attach_part(model,*attach_model_part)
            self.stack.append(key)  # add to LRU cache
            self.models[key]=model  # keep copy of model in dict
@@ -163,7 +194,7 @@ class ModelCache(object):
            self.locked_models[key] += 1
            if self.lazy_offloading:
                self._offload_unlocked_models()
            logger.debug(f'Loading {key} into {self.execution_device}')
            self.logger.debug(f'Loading {key} into {self.execution_device}')
            model.to(self.execution_device)  # move into GPU
            self._print_cuda_stats()
            yield model
@@ -181,25 +212,59 @@ class ModelCache(object):
            self.loaded_models.remove(key)
            yield model

    def _offload_unlocked_models(self):
        to_offload = set()
        for key in self.loaded_models:
            if key not in self.locked_models or self.locked_models[key] == 0:
                logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
                to_offload.add(key)
        for key in to_offload:
            self.models[key].to(self.storage_device)
            self.loaded_models.remove(key)

    def attach_part(self,
                    diffusers_model: StableDiffusionPipeline,
                    part_type: SDModelType,
                    part_id: str
                    ):
        '''
        Attach a diffusers model part to a diffusers model. This can be
        used to replace the VAE, tokenizer, text encoder, unet, etc.
        :param diffusers_model: The diffusers model to attach the part to.
        :param part_type: An SDModelType indicating the part
        :param part_id: A HF repo_id for the part
        '''
        part_key = part_type.name
        part_class = part_type.value
        part = self._load_diffusers_from_storage(
            part_id,
            model_class=part_class,
        )
        part.to(diffusers_model.device)
        setattr(diffusers_model,part_key,part)
        self.logger.debug(f'Attached {part_key} {part_id}')

    def status(self,
               repo_id_or_path: Union[str,Path],
               model_type: SDModelType=SDModelType.diffusion_pipeline,
               revision: str=None,
               subfolder: Path=None,
               )->ModelStatus:
        key = self._model_key(
            repo_id_or_path,
            model_type.value,
            revision,
            subfolder)
        if key not in self.models:
            return ModelStatus.not_loaded
        if key in self.loaded_models:
            if self.locked_models[key] > 0:
                return ModelStatus.active
            else:
                return ModelStatus.in_vram
        else:
            return ModelStatus.in_ram

    def model_hash(self,
                   repo_id_or_path: Union[str,Path],
                   revision: str=None)->str:
                   revision: str="main")->str:
        '''
        Given the HF repo id or path to a model on disk, returns a unique
        hash. Works for legacy checkpoint files, HF models on disk, and HF repo IDs
        :param repo_id_or_path: repo_id string or Path to model file/directory on disk.
        :param revision: optional revision string (if fetching a HF repo_id)
        '''
        revision = revision or "main"
        if self.is_legacy_ckpt(repo_id_or_path):
            return self._legacy_model_hash(repo_id_or_path)
        elif Path(repo_id_or_path).is_dir():
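A short sketch of the new introspection helpers defined above; the repo id is an illustrative assumption and `cache` is a previously constructed ModelCache:

    from invokeai.backend.model_management.model_cache import ModelStatus

    status = cache.status('runwayml/stable-diffusion-v1-5')
    if status in (ModelStatus.in_vram, ModelStatus.active):
        print('model is already on the GPU')
    elif status is ModelStatus.in_ram:
        print('model is cached in CPU RAM and can be moved to the GPU quickly')

    # model_hash() now defaults to the 'main' revision when none is supplied.
    digest = cache.model_hash('runwayml/stable-diffusion-v1-5')
    print(f'unique hash for this model: {digest}')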
@@ -211,96 +276,6 @@ class ModelCache(object):
        "Return the current number of models cached."
        return len(self.models)

    @staticmethod
    def _model_key(path,model_class,revision,subfolder)->str:
        return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])

    def _has_cuda(self)->bool:
        return self.execution_device.type == 'cuda'

    def _print_cuda_stats(self):
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
        loaded_models = len(self.loaded_models)
        locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")

    def _make_cache_room(self):
        models_in_ram = len(self.models)
        while models_in_ram >= self.max_models:
            if least_recently_used_key := self.stack.pop(0):
                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                del self.models[least_recently_used_key]
            models_in_ram = len(self.models)
        gc.collect()

    @property
    def current_model(self)->ModelClass:
        '''
        Returns current model.
        '''
        return self.models[self._current_model_key]

    @property
    def _current_model_key(self)->str:
        '''
        Returns key of currently loaded model.
        '''
        return self.stack[-1]

    def _load_model_from_storage(
        self,
        repo_id_or_path: Union[str,Path],
        subfolder: Path=None,
        revision: str=None,
        model_class: ModelClass=StableDiffusionGeneratorPipeline,
        legacy_info: LegacyInfo=None,
    )->ModelClass:
        '''
        Load and return a HuggingFace model.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        '''
        # silence transformer and diffuser warnings
        with SilenceWarnings():
            if self.is_legacy_ckpt(repo_id_or_path):
                model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
            else:
                model = self._load_diffusers_from_storage(
                    repo_id_or_path,
                    subfolder,
                    revision,
                    model_class,
                )
        if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
            model.enable_offload_submodels(self.execution_device)
        elif hasattr(model,'to'):
            model.to(self.execution_device)
        return model

    def _load_diffusers_from_storage(
        self,
        repo_id_or_path: Union[str,Path],
        subfolder: Path=None,
        revision: str=None,
        model_class: ModelClass=StableDiffusionGeneratorPipeline,
    )->ModelClass:
        '''
        Load and return a HuggingFace model using from_pretrained().
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
        '''
        return model_class.from_pretrained(
            repo_id_or_path,
            revision=revision,
            subfolder=subfolder or '.',
            cache_dir=global_cache_dir('hub'),
        )

    @classmethod
    def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
        '''
@@ -327,6 +302,106 @@ class ModelCache(object):
        else:
            logger.debug("Model scanned ok")

    @staticmethod
    def _model_key(path,model_class,revision,subfolder)->str:
        return ':'.join([str(path),model_class.__name__,str(revision or ''),str(subfolder or '')])

    def _has_cuda(self)->bool:
        return self.execution_device.type == 'cuda'

    def _print_cuda_stats(self):
        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
        loaded_models = len(self.loaded_models)
        locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")

    def _make_cache_room(self):
        models_in_ram = len(self.models)
        while models_in_ram >= self.max_models:
            if least_recently_used_key := self.stack.pop(0):
                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                del self.models[least_recently_used_key]
            models_in_ram = len(self.models)
        gc.collect()

    def _offload_unlocked_models(self):
        to_offload = set()
        for key in self.loaded_models:
            if key not in self.locked_models or self.locked_models[key] == 0:
                self.logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
                to_offload.add(key)
        for key in to_offload:
            self.models[key].to(self.storage_device)
            self.loaded_models.remove(key)

    def _load_model_from_storage(
        self,
        repo_id_or_path: Union[str,Path],
        subfolder: Path=None,
        revision: str=None,
        model_class: ModelClass=StableDiffusionGeneratorPipeline,
        legacy_info: LegacyInfo=None,
    )->ModelClass:
        '''
        Load and return a HuggingFace model.
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
        '''
        if self.is_legacy_ckpt(repo_id_or_path):
            model = self._load_ckpt_from_storage(repo_id_or_path, legacy_info)
        else:
            model = self._load_diffusers_from_storage(
                repo_id_or_path,
                subfolder,
                revision,
                model_class,
            )
        if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
            model.enable_offload_submodels(self.execution_device)
        return model

    def _load_diffusers_from_storage(
        self,
        repo_id_or_path: Union[str,Path],
        subfolder: Path=None,
        revision: str=None,
        model_class: ModelClass=StableDiffusionGeneratorPipeline,
    )->ModelClass:
        '''
        Load and return a HuggingFace model using from_pretrained().
        :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
        :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
        '''
        self.logger.info(f'Loading model {repo_id_or_path}')
        revisions = [revision] if revision \
            else ['fp16','main'] if self.precision==torch.float16 \
            else ['main']
        extra_args = {'precision': self.precision} \
            if model_class in DiffusionClasses \
            else {}

        # silence transformer and diffuser warnings
        with SilenceWarnings():
            for rev in revisions:
                try:
                    model = model_class.from_pretrained(
                        repo_id_or_path,
                        revision=rev,
                        subfolder=subfolder or '.',
                        cache_dir=global_cache_dir('hub'),
                        **extra_args,
                    )
                    self.logger.debug(f'Found revision {rev}')
                    break
                except OSError:
                    pass
        return model

    def _load_ckpt_from_storage(self,
                                ckpt_path: Union[str,Path],
                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
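The rewritten _load_diffusers_from_storage() above tries the 'fp16' revision first when the cache runs at half precision and then falls back to 'main'. A standalone sketch of that fallback pattern, with an assumed repo id and dtype:

    import torch
    from diffusers import StableDiffusionPipeline

    precision = torch.float16
    revisions = ['fp16', 'main'] if precision == torch.float16 else ['main']

    pipeline = None
    for rev in revisions:
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                'runwayml/stable-diffusion-v1-5',   # illustrative repo id
                revision=rev,
                torch_dtype=precision,
            )
            break                                   # first revision that exists wins
        except OSError:
            continue                                # revision not published; try the next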
@@ -336,6 +411,10 @@ class ModelCache(object):
        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
        '''
        assert legacy_info is not None

        # deferred loading to avoid circular import errors
        from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
        with SilenceWarnings():
            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
                checkpoint_path=ckpt_path,
                original_config_file=legacy_info.config_file,
@@ -399,6 +478,7 @@ class ModelCache(object):
            raise KeyError(f"Revision '{revision}' not found in {repo_id}")
        return desired_revisions[0].target_commit


class SilenceWarnings(object):
    def __init__(self):
        self.transformers_verbosity = transformers_logging.get_verbosity()
@@ -1,56 +1,88 @@
"""
Manage a cache of Stable Diffusion model files for fast switching.
They are moved between GPU and CPU as necessary. If CPU memory falls
below a preset minimum, the least recently used model will be
cleared and loaded from disk when next needed.
"""This module manages the InvokeAI `models.yaml` file, mapping
symbolic diffusers model names to the paths and repo_ids used
by the underlying `from_pretrained()` call.

For fetching models, use manager.get_model('symbolic name'). This will
return a SDModelInfo object that contains the following attributes:

  * context -- a context manager Generator that loads and locks the
               model into GPU VRAM and returns the model for use.
               See below for usage.
  * name -- symbolic name of the model
  * hash -- unique hash for the model
  * location -- path or repo_id of the model
  * revision -- revision of the model if coming from a repo id,
                e.g. 'fp16'
  * precision -- torch precision of the model
  * status -- a ModelStatus enum corresponding to one of
              'not_loaded', 'in_ram', 'in_vram' or 'active'

Typical usage:

  from invokeai.backend import ModelManager
  manager = ModelManager(config_path='./configs/models.yaml',max_models=4)
  model_info = manager.get_model('stable-diffusion-1.5')
  with model_info.context as my_model:
      my_model.latents_from_embeddings(...)

The manager uses the underlying ModelCache class to keep
frequently-used models in RAM and move them into GPU as needed for
generation operations. The ModelCache object can be accessed using
the manager's "cache" attribute.

Other methods provided by ModelManager support importing, editing,
converting and deleting models.
"""
from __future__ import annotations

import contextlib
import gc
import hashlib
import os
import re
import sys
import textwrap
import time
import warnings
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from shutil import move, rmtree
from typing import Any, Optional, Union, Callable, types
from shutil import rmtree
from typing import Union, Callable, types

import safetensors
import safetensors.torch
import torch
import transformers
import invokeai.backend.util.logging as logger
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    SchedulerMixin,
    logging as dlogging,
)
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from picklescan.scanner import scan_file_path

from invokeai.backend.globals import Globals, global_cache_dir
from invokeai.backend.globals import Globals, global_cache_dir, global_resolve_path
from .model_cache import Generator, ModelClass, ModelCache, SDModelType, ModelStatus, LegacyInfo

from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPFeatureExtractor,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
from ..util import CUDA_DEVICE

# wanted to use pydantic here, but Generator objects not supported
@dataclass
class SDModelInfo():
    context: Generator[ModelClass, None, None]
    name: str
    hash: str
    location: Union[Path,str]
    precision: torch.dtype
    subfolder: Path = None
    revision: str = None
    _cache: ModelCache = None

    def status(self)->ModelStatus:
        '''Return load status of this model as a model_cache.ModelStatus enum'''
        if not self._cache:
            return ModelStatus.unknown
        return self._cache.status(
            self.location,
            revision = self.revision,
            subfolder = self.subfolder
        )
from ..stable_diffusion import (
    StableDiffusionGeneratorPipeline,
)
from ..util import CUDA_DEVICE, ask_user, download_with_resume

class InvalidModelError(Exception):
    "Raised when an invalid model is requested"
    pass

class SDLegacyType(Enum):
    V1 = auto()
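A small sketch of consuming the SDModelInfo record described in the module docstring above; the symbolic model name follows the docstring's own example:

    manager = ModelManager(config_path='./configs/models.yaml', max_models=4)
    model_info = manager.get_model('stable-diffusion-1.5')

    print(model_info.name, model_info.location, model_info.hash, model_info.precision)
    print(model_info.status())    # ModelStatus.not_loaded / in_ram / in_vram / active

    with model_info.context as model:
        model.latents_from_embeddings(...)   # placeholder arguments, as in the docstring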
@@ -60,54 +92,39 @@ class SDLegacyType(Enum):
    V2_v = auto()
    UNKNOWN = auto()

class SDModelComponent(Enum):
    vae="vae"
    text_encoder="text_encoder"
    tokenizer="tokenizer"
    unet="unet"
    scheduler="scheduler"
    safety_checker="safety_checker"
    feature_extractor="feature_extractor"

DEFAULT_MAX_MODELS = 2

class ModelManager(object):
    """
    Model manager handles loading, caching, importing, deleting, converting, and editing models.
    High-level interface to model management.
    """

    logger: types.ModuleType = logger

    def __init__(
        self,
        config: OmegaConf | Path,
        config_path: Path,
        device_type: torch.device = CUDA_DEVICE,
        precision: str = "float16",
        max_loaded_models=DEFAULT_MAX_MODELS,
        precision: torch.dtype = torch.float16,
        max_models=DEFAULT_MAX_MODELS,
        sequential_offload=False,
        embedding_path: Path = None,
        logger: types.ModuleType = logger,
    ):
        """
        Initialize with the path to the models.yaml config file or
        an initialized OmegaConf dictionary. Optional parameters
        are the torch device type, precision, max_loaded_models,
        Initialize with the path to the models.yaml config file.
        Optional parameters are the torch device type, precision, max_models,
        and sequential_offload boolean. Note that the default device
        type and precision are set up for a CUDA system running at half precision.
        """
        # prevent nasty-looking CLIP log message
        transformers.logging.set_verbosity_error()
        if not isinstance(config, DictConfig):
            config = OmegaConf.load(config)
        self.config = config
        self.precision = precision
        self.device = torch.device(device_type)
        self.max_loaded_models = max_loaded_models
        self.models = {}
        self.stack = []  # this is an LRU FIFO
        self.current_model = None
        self.sequential_offload = sequential_offload
        self.embedding_path = embedding_path
        self.config_path = config_path
        self.config = OmegaConf.load(self.config_path)
        self.cache = ModelCache(
            max_models=max_models,
            execution_device = device_type,
            precision = precision,
            sequential_offload = sequential_offload,
            logger = logger,
        )
        self.logger = logger

    def valid_model(self, model_name: str) -> bool:
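An illustrative construction with the new __init__ signature above; the path and parameter values are assumptions:

    import torch
    from pathlib import Path
    from invokeai.backend import ModelManager

    manager = ModelManager(
        config_path=Path('./configs/models.yaml'),
        precision=torch.float16,     # now a torch.dtype rather than the string "float16"
        max_models=4,                # renamed from max_loaded_models
        sequential_offload=False,
    )
    print(manager.valid_model('stable-diffusion-1.5'))   # assumed entry in models.yaml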
@@ -117,138 +134,69 @@ class ModelManager(object):
        """
        return model_name in self.config

    def get_model(self, model_name: str = None) -> dict:
        """Given a model name identified in models.yaml, return a dict
        containing the model object and some of its key features. If
        in RAM will load into GPU VRAM. If on disk, will load from
        there.
        The dict has the following keys:
        'model': The StableDiffusionGeneratorPipeline object
        'model_name': The name of the model in models.yaml
        'width': The width of images trained by this model
        'height': The height of images trained by this model
        'hash': A unique hash of this model's files on disk.
    def get_model(self,
                  model_name: str = None,
                  submodel: SDModelType=None,
                  ) -> SDModelInfo:
        """Given a model name identified in models.yaml, return
        an SDModelInfo object describing it.
        :param model_name: symbolic name of the model in models.yaml
        :param submodel: an SDModelType enum indicating the portion of
               the model to retrieve (e.g. SDModelType.vae)
        """
        if not model_name:
            return (
                self.get_model(self.current_model)
                if self.current_model
                else self.get_model(self.default_model())
            )
            model_name = self.default_model()

        if not self.valid_model(model_name):
            self.logger.error(
            raise InvalidModelError(
                f'"{model_name}" is not a known model name. Please check your models.yaml file'
            )
            return self.current_model

        if self.current_model != model_name:
            if model_name not in self.models:  # make room for a new one
                self._make_cache_room()
                self.offload_model(self.current_model)
        # get the required loading info out of the config file
        mconfig = self.config[model_name]
        format = mconfig.get('format','diffusers')
        legacy = None
        if format=='ckpt':
            location = global_resolve_path(mconfig.weights)
            legacy = LegacyInfo(
                config_file = global_resolve_path(mconfig.config),
            )
            if mconfig.get('vae'):
                legacy.vae_file = global_resolve_path(mconfig.vae)
        elif format=='diffusers':
            location = mconfig.repo_id
            revision = mconfig.get('revision')
        else:
            raise InvalidModelError(
                f'"{model_name}" has an unknown format {format}'
            )

        if model_name in self.models:
            requested_model = self.models[model_name]["model"]
            self.logger.info(f"Retrieving model {model_name} from system RAM cache")
            requested_model.ready()
            width = self.models[model_name]["width"]
            height = self.models[model_name]["height"]
            hash = self.models[model_name]["hash"]

        else:  # we're about to load a new model, so potentially offload the least recently used one
            requested_model, width, height, hash = self._load_model(model_name)
            self.models[model_name] = {
                "model_name": model_name,
                "model": requested_model,
                "width": width,
                "height": height,
                "hash": hash,
            }

        self.current_model = model_name
        self._push_newest_model(model_name)
        return {
            "model_name": model_name,
            "model": requested_model,
            "width": width,
            "height": height,
            "hash": hash,
        }

    def get_model_vae(self, model_name: str=None)->AutoencoderKL:
        """Given a model name identified in models.yaml, load the model into
        GPU if necessary and return its assigned VAE as an
        AutoencoderKL object. If no model name is provided, return the
        vae from the model currently in the GPU.
        """
        return self._get_sub_model(model_name, SDModelComponent.vae)

    def get_model_tokenizer(self, model_name: str=None)->CLIPTokenizer:
        """Given a model name identified in models.yaml, load the model into
        GPU if necessary and return its assigned CLIPTokenizer. If no
        model name is provided, return the tokenizer from the model
        currently in the GPU.
        """
        return self._get_sub_model(model_name, SDModelComponent.tokenizer)

    def get_model_unet(self, model_name: str=None)->UNet2DConditionModel:
        """Given a model name identified in models.yaml, load the model into
        GPU if necessary and return its assigned UNet2DConditionModel. If no model
        name is provided, return the UNet from the model
        currently in the GPU.
        """
        return self._get_sub_model(model_name, SDModelComponent.unet)

    def get_model_text_encoder(self, model_name: str=None)->CLIPTextModel:
        """Given a model name identified in models.yaml, load the model into
        GPU if necessary and return its assigned CLIPTextModel. If no
        model name is provided, return the text encoder from the model
        currently in the GPU.
        """
        return self._get_sub_model(model_name, SDModelComponent.text_encoder)

    def get_model_feature_extractor(self, model_name: str=None)->CLIPFeatureExtractor:
        """Given a model name identified in models.yaml, load the model into
        GPU if necessary and return its assigned CLIPFeatureExtractor. If no
        model name is provided, return the text encoder from the model
        currently in the GPU.
        """
        return self._get_sub_model(model_name, SDModelComponent.feature_extractor)

    def get_model_scheduler(self, model_name: str=None)->SchedulerMixin:
        """Given a model name identified in models.yaml, load the model into
        GPU if necessary and return its assigned scheduler. If no
        model name is provided, return the text encoder from the model
        currently in the GPU.
        """
        return self._get_sub_model(model_name, SDModelComponent.scheduler)

    def _get_sub_model(
        self,
        model_name: str=None,
        model_part: SDModelComponent=SDModelComponent.vae,
    ) -> Union[
        AutoencoderKL,
        CLIPTokenizer,
        CLIPFeatureExtractor,
        UNet2DConditionModel,
        CLIPTextModel,
        StableDiffusionSafetyChecker,
    ]:
        """Given a model name identified in models.yaml, and the part of the
        model you wish to retrieve, return that part. Parts are in an Enum
        class named SDModelComponent, and consist of:
        SDModelComponent.vae
        SDModelComponent.text_encoder
        SDModelComponent.tokenizer
        SDModelComponent.unet
        SDModelComponent.scheduler
        SDModelComponent.safety_checker
        SDModelComponent.feature_extractor
        """
        model_dict = self.get_model(model_name)
        model = model_dict["model"]
        return getattr(model, model_part.value)
        subfolder = mconfig.get('subfolder')
        hash = self.cache.model_hash(location,revision)
        vae = (None,None)
        try:
            vae_id = mconfig.vae.repo_id
            vae = (SDModelType.vae,vae_id)
        except Exception:
            pass
        model_context = self.cache.get_model(
            location,
            revision = revision,
            subfolder = subfolder,
            legacy_info = legacy,
            submodel = submodel,
            attach_model_part=vae,
        )
        return SDModelInfo(
            context = model_context,
            name = model_name,
            hash = hash,
            location = location,
            revision = revision,
            precision = self.cache.precision,
            subfolder = subfolder,
            _cache = self.cache
        )

    def default_model(self) -> str | None:
        """
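Callers that previously unpacked the returned dict ('model', 'width', 'height', 'hash') now read attributes of the SDModelInfo object instead. A hedged sketch of the new call pattern, with assumed model names:

    info = manager.get_model('stable-diffusion-1.5')
    print(info.hash, info.location, info.status())

    # Fetch only the VAE of the same model via the new submodel argument.
    vae_info = manager.get_model('stable-diffusion-1.5', submodel=SDModelType.vae)
    with vae_info.context as vae:
        latents = vae.encode(...)    # placeholder arguments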
@@ -324,19 +272,19 @@ class ModelManager(object):
            format = stanza.get("format", "ckpt")  # Determine Format

            # Common Attribs
            status = self.cache.status(
                stanza.get('weights') or stanza.get('repo_id'),
                revision=stanza.get('revision'),
                subfolder=stanza.get('subfolder')
            )
            description = stanza.get("description", None)
            if self.current_model == name:
                status = "active"
            elif name in self.models:
                status = "cached"
            else:
                status = "not loaded"
            models[name].update(
                description=description,
                format=format,
                status=status,
                status=status.value
            )

            # Checkpoint Config Parse
            if format == "ckpt":
                models[name].update(
|
||||
for name in models:
|
||||
if models[name]["format"] == "vae":
|
||||
continue
|
||||
line = f'{name:25s} {models[name]["status"]:>10s} {models[name]["format"]:10s} {models[name]["description"]}'
|
||||
line = f'{name:25s} {models[name]["status"]:>15s} {models[name]["format"]:10s} {models[name]["description"]}'
|
||||
if models[name]["status"] == "active":
|
||||
line = f"\033[1m{line}\033[0m"
|
||||
print(line)
|
||||
@ -441,233 +389,6 @@ class ModelManager(object):
|
||||
if clobber:
|
||||
self._invalidate_cached_model(model_name)
|
||||
|
||||
def _load_model(self, model_name: str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
self.logger.error(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
return
|
||||
|
||||
mconfig = self.config[model_name]
|
||||
|
||||
# for usage statistics
|
||||
if self._has_cuda():
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
model_format = mconfig.get("format", "ckpt")
|
||||
if model_format == "ckpt":
|
||||
weights = mconfig.weights
|
||||
self.logger.info(f"Loading {model_name} from {weights}")
|
||||
model, width, height, model_hash = self._load_ckpt_model(
|
||||
model_name, mconfig
|
||||
)
|
||||
elif model_format == "diffusers":
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
model, width, height, model_hash = self._load_diffusers_model(mconfig)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown model format {model_name}: {model_format}"
|
||||
)
|
||||
self._add_embeddings_to_model(model)
|
||||
|
||||
# usage statistics
|
||||
toc = time.time()
|
||||
self.logger.info("Model loaded in " + "%4.2fs" % (toc - tic))
|
||||
if self._has_cuda():
|
||||
self.logger.info(
|
||||
"Max VRAM used to load the model: "+
|
||||
"%4.2fG" % (torch.cuda.max_memory_allocated() / 1e9)
|
||||
)
|
||||
self.logger.info(
|
||||
"Current VRAM usage: "+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
def _load_diffusers_model(self, mconfig):
|
||||
name_or_path = self.model_name_or_path(mconfig)
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
self.logger.info(f"Loading diffusers model from {name_or_path}")
|
||||
if using_fp16:
|
||||
self.logger.debug("Using faster float16 precision")
|
||||
else:
|
||||
self.logger.debug("Using more accurate float32 precision")
|
||||
|
||||
# TODO: scan weights maybe?
|
||||
pipeline_args: dict[str, Any] = dict(
|
||||
safety_checker=None, local_files_only=not Globals.internet_available
|
||||
)
|
||||
if "vae" in mconfig and mconfig["vae"] is not None:
|
||||
if vae := self._load_vae(mconfig["vae"]):
|
||||
pipeline_args.update(vae=vae)
|
||||
if not isinstance(name_or_path, Path):
|
||||
pipeline_args.update(cache_dir=global_cache_dir("hub"))
|
||||
if using_fp16:
|
||||
pipeline_args.update(torch_dtype=torch.float16)
|
||||
fp_args_list = [{"revision": "fp16"}, {}]
|
||||
else:
|
||||
fp_args_list = [{}]
|
||||
|
||||
verbosity = dlogging.get_verbosity()
|
||||
dlogging.set_verbosity_error()
|
||||
|
||||
pipeline = None
|
||||
for fp_args in fp_args_list:
|
||||
try:
|
||||
pipeline = StableDiffusionGeneratorPipeline.from_pretrained(
|
||||
name_or_path,
|
||||
**pipeline_args,
|
||||
**fp_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
self.logger.error(
|
||||
f"An unexpected error occurred while downloading the model: {e})"
|
||||
)
|
||||
if pipeline:
|
||||
break
|
||||
|
||||
dlogging.set_verbosity(verbosity)
|
||||
assert pipeline is not None, OSError(f'"{name_or_path}" could not be loaded')
|
||||
|
||||
if self.sequential_offload:
|
||||
pipeline.enable_offload_submodels(self.device)
|
||||
else:
|
||||
pipeline.to(self.device)
|
||||
|
||||
model_hash = self._diffuser_sha256(name_or_path)
|
||||
|
||||
# square images???
|
||||
width = pipeline.unet.config.sample_size * pipeline.vae_scale_factor
|
||||
height = width
|
||||
self.logger.debug(f"Default image dimensions = {width} x {height}")
|
||||
|
||||
return pipeline, width, height, model_hash
|
||||
|
||||
def _load_ckpt_model(self, model_name, mconfig):
|
||||
config = mconfig.config
|
||||
weights = mconfig.weights
|
||||
vae = mconfig.get("vae")
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root, config)
|
||||
if not os.path.isabs(weights):
|
||||
weights = os.path.normpath(os.path.join(Globals.root, weights))
|
||||
|
||||
# Convert to diffusers and return a diffusers pipeline
|
||||
self.logger.info(f"Converting legacy checkpoint {model_name} into a diffusers model...")
|
||||
|
||||
from . import load_pipeline_from_original_stable_diffusion_ckpt
|
||||
|
||||
try:
|
||||
if self.list_models()[self.current_model]["status"] == "active":
|
||||
self.offload_model(self.current_model)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
vae_path = None
|
||||
if vae:
|
||||
vae_path = (
|
||||
vae
|
||||
if os.path.isabs(vae)
|
||||
else os.path.normpath(os.path.join(Globals.root, vae))
|
||||
)
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path=weights,
|
||||
original_config_file=config,
|
||||
vae_path=vae_path,
|
||||
return_generator_pipeline=True,
|
||||
precision=torch.float16 if self.precision == "float16" else torch.float32,
|
||||
)
|
||||
if self.sequential_offload:
|
||||
pipeline.enable_offload_submodels(self.device)
|
||||
else:
|
||||
pipeline.to(self.device)
|
||||
return (
|
||||
pipeline,
|
||||
width,
|
||||
height,
|
||||
"NOHASH",
|
||||
)
|
||||
|
||||
def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
|
||||
if isinstance(model_name, DictConfig) or isinstance(model_name, dict):
|
||||
mconfig = model_name
|
||||
elif model_name in self.config:
|
||||
mconfig = self.config[model_name]
|
||||
else:
|
||||
raise ValueError(
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
|
||||
if "path" in mconfig and mconfig["path"] is not None:
|
||||
path = Path(mconfig["path"])
|
||||
if not path.is_absolute():
|
||||
path = Path(Globals.root, path).resolve()
|
||||
return path
|
||||
elif "repo_id" in mconfig:
|
||||
return mconfig["repo_id"]
|
||||
else:
|
||||
raise ValueError("Model config must specify either repo_id or path.")
|
||||
|
||||
def offload_model(self, model_name: str) -> None:
|
||||
"""
|
||||
Offload the indicated model to CPU. Will call
|
||||
_make_cache_room() to free space if needed.
|
||||
"""
|
||||
if model_name not in self.models:
|
||||
return
|
||||
|
||||
self.logger.info(f"Offloading {model_name} to CPU")
|
||||
model = self.models[model_name]["model"]
|
||||
model.offload_all()
|
||||
self.current_model = None
|
||||
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@classmethod
|
||||
def scan_model(self, model_name, checkpoint):
|
||||
"""
|
||||
Apply picklescanner to the indicated checkpoint and issue a warning
|
||||
and option to exit if an infected file is identified.
|
||||
"""
|
||||
# scan model
|
||||
self.logger.debug(f"Scanning Model: {model_name}")
|
||||
scan_result = scan_file_path(checkpoint)
|
||||
if scan_result.infected_files != 0:
|
||||
if scan_result.infected_files == 1:
|
||||
self.logger.critical(f"Issues Found In Model: {scan_result.issues_count}")
|
||||
self.logger.critical("The model you are trying to load seems to be infected.")
|
||||
self.logger.critical("For your safety, InvokeAI will not load this model.")
|
||||
self.logger.critical("Please use checkpoints from trusted sources.")
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
self.logger.warning("InvokeAI was unable to scan the model you are using.")
|
||||
model_safe_check_fail = ask_user(
|
||||
"Do you want to to continue loading the model?", ["y", "n"]
|
||||
)
|
||||
if model_safe_check_fail.lower() != "y":
|
||||
self.logger.critical("Exiting InvokeAI")
|
||||
sys.exit()
|
||||
else:
|
||||
self.logger.debug("Model scanned ok")
|
||||
|
||||
def import_diffuser_model(
|
||||
self,
|
||||
repo_or_path: Union[str, Path],
|
||||
@@ -949,8 +670,6 @@ class ModelManager(object):

        new_config = None

        from . import convert_ckpt_to_diffusers

        if diffusers_path.exists():
            self.logger.error(
                f"The path {str(diffusers_path)} already exists. Please move or remove it and try again."
|
||||
model_name = model_name or diffusers_path.name
|
||||
model_description = model_description or f"Converted version of {model_name}"
|
||||
self.logger.debug(f"Converting {model_name} to diffusers (30-60s)")
|
||||
|
||||
# to avoid circular import errors
|
||||
from .convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||
|
||||
try:
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
@ -1020,33 +743,12 @@ class ModelManager(object):
|
||||
|
||||
return search_folder, found_models
|
||||
|
||||
def _make_cache_room(self) -> None:
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
self.logger.info(
|
||||
f"Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}"
|
||||
)
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
|
||||
def print_vram_usage(self) -> None:
|
||||
if self._has_cuda:
|
||||
self.logger.info(
|
||||
"Current VRAM usage:"+
|
||||
"%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
|
||||
)
|
||||
|
||||
def commit(self, config_file_path: str) -> None:
|
||||
def commit(self) -> None:
|
||||
"""
|
||||
Write current configuration out to the indicated file.
|
||||
"""
|
||||
yaml_str = OmegaConf.to_yaml(self.config)
|
||||
if not os.path.isabs(config_file_path):
|
||||
config_file_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config_file_path)
|
||||
)
|
||||
config_file_path = self.config_path
|
||||
tmpfile = os.path.join(os.path.dirname(config_file_path), "new_config.tmp")
|
||||
with open(tmpfile, "w", encoding="utf-8") as outfile:
|
||||
outfile.write(self.preamble())
|
||||
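commit() no longer takes a path and always writes back to the config file the manager was constructed with, so a call site now reduces to a sketch like this (the in-place edit shown is hypothetical):

    manager.config['my-model']['description'] = 'updated description'   # hypothetical edit
    manager.commit()    # previously: manager.commit('configs/models.yaml')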
@@ -1069,240 +771,6 @@ class ModelManager(object):
        """
        )

    @classmethod
    def migrate_models(cls):
        """
        Migrate the ~/invokeai/models directory from the legacy format used through 2.2.5
        to the 2.3.0 "diffusers" version. This should be a one-time operation, called at
        script startup time.
        """
        # Three transformer models to check: bert, clip and safety checker, and
        # the diffusers as well
        models_dir = Path(Globals.root, "models")
        legacy_locations = [
            Path(
                models_dir,
                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
            ),
            Path(models_dir, "bert-base-uncased/models--bert-base-uncased"),
            Path(
                models_dir,
                "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14",
            ),
        ]
        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))

        legacy_layout = False
        for model in legacy_locations:
            legacy_layout = legacy_layout or model.exists()
        if not legacy_layout:
            return

        print(
            """
>> ALERT:
>> The location of your previously-installed diffusers models needs to move from
>> invokeai/models/diffusers to invokeai/models/hub due to a change introduced by
>> diffusers version 0.14. InvokeAI will now move all models from the "diffusers" directory
>> into "hub" and then remove the diffusers directory. This is a quick, safe, one-time
>> operation."""
        )

        # transformer files get moved into the hub directory
        if cls._is_huggingface_hub_directory_present():
            hub = global_cache_dir("hub")
        else:
            hub = models_dir / "hub"

        os.makedirs(hub, exist_ok=True)
        for model in legacy_locations:
            source = models_dir / model
            dest = hub / model.stem
            if dest.exists() and not source.exists():
                continue
            cls.logger.info(f"{source} => {dest}")
            if source.exists():
                if dest.is_symlink():
                    logger.warning(f"Found symlink at {dest.name}. Not migrating.")
                elif dest.exists():
                    if source.is_dir():
                        rmtree(source)
                    else:
                        source.unlink()
                else:
                    move(source, dest)

        # now clean up by removing any empty directories
        empty = [
            root
            for root, dirs, files, in os.walk(models_dir)
            if not len(dirs) and not len(files)
        ]
        for d in empty:
            os.rmdir(d)
        cls.logger.info("Migration is done. Continuing...")

    def _resolve_path(
        self, source: Union[str, Path], dest_directory: str
    ) -> Optional[Path]:
        resolved_path = None
        if str(source).startswith(("http:", "https:", "ftp:")):
            dest_directory = Path(dest_directory)
            if not dest_directory.is_absolute():
                dest_directory = Globals.root / dest_directory
            dest_directory.mkdir(parents=True, exist_ok=True)
            resolved_path = download_with_resume(str(source), dest_directory)
        else:
            if not os.path.isabs(source):
                source = os.path.join(Globals.root, source)
            resolved_path = Path(source)
        return resolved_path

    def _invalidate_cached_model(self, model_name: str) -> None:
        self.offload_model(model_name)
        if model_name in self.stack:
            self.stack.remove(model_name)
        self.models.pop(model_name, None)

    def _pop_oldest_model(self):
        """
        Remove the first element of the FIFO, which ought
        to be the least recently accessed model. Do not
        pop the last one, because it is in active use!
        """
        return self.stack.pop(0)

    def _push_newest_model(self, model_name: str) -> None:
        """
        Maintain a simple FIFO. First element is always the
        least recent, and last element is always the most recent.
        """
        with contextlib.suppress(ValueError):
            self.stack.remove(model_name)
        self.stack.append(model_name)

    def _add_embeddings_to_model(self, model: StableDiffusionGeneratorPipeline):
        if self.embedding_path is not None:
            self.logger.info(f"Loading embeddings from {self.embedding_path}")
            for root, _, files in os.walk(self.embedding_path):
                for name in files:
                    ti_path = os.path.join(root, name)
                    model.textual_inversion_manager.load_textual_inversion(
                        ti_path, defer_injecting_tokens=True
                    )
            self.logger.info(
                f'Textual inversion triggers: {", ".join(sorted(model.textual_inversion_manager.get_all_trigger_strings()))}'
            )

    def _has_cuda(self) -> bool:
        return self.device.type == "cuda"

    def _diffuser_sha256(
        self, name_or_path: Union[str, Path], chunksize=16777216
    ) -> Union[str, bytes]:
        path = None
        if isinstance(name_or_path, Path):
            path = name_or_path
        else:
            owner, repo = name_or_path.split("/")
            path = Path(global_cache_dir("hub") / f"models--{owner}--{repo}")
        if not path.exists():
            return None
        hashpath = path / "checksum.sha256"
        if hashpath.exists() and path.stat().st_mtime <= hashpath.stat().st_mtime:
            with open(hashpath) as f:
                hash = f.read()
            return hash
        self.logger.debug("Calculating sha256 hash of model files")
        tic = time.time()
        sha = hashlib.sha256()
        count = 0
        for root, dirs, files in os.walk(path, followlinks=False):
            for name in files:
                count += 1
                with open(os.path.join(root, name), "rb") as f:
                    while chunk := f.read(chunksize):
                        sha.update(chunk)
        hash = sha.hexdigest()
        toc = time.time()
        self.logger.debug(f"sha256 = {hash} ({count} files hashed in", "%4.2fs)" % (toc - tic))
        with open(hashpath, "w") as f:
            f.write(hash)
        return hash

    def _cached_sha256(self, path, data) -> Union[str, bytes]:
        dirname = os.path.dirname(path)
        basename = os.path.basename(path)
        base, _ = os.path.splitext(basename)
        hashpath = os.path.join(dirname, base + ".sha256")

        if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(
            hashpath
        ):
            with open(hashpath) as f:
                hash = f.read()
            return hash

        self.logger.debug("Calculating sha256 hash of weights file")
        tic = time.time()
        sha = hashlib.sha256()
        sha.update(data)
        hash = sha.hexdigest()
        toc = time.time()
        self.logger.debug(f"sha256 = {hash} "+"(%4.2fs)" % (toc - tic))

        with open(hashpath, "w") as f:
            f.write(hash)
        return hash

    def _load_vae(self, vae_config) -> AutoencoderKL:
        vae_args = {}
        try:
            name_or_path = self.model_name_or_path(vae_config)
        except Exception:
            return None
        if name_or_path is None:
            return None
        using_fp16 = self.precision == "float16"

        vae_args.update(
            cache_dir=global_cache_dir("hub"),
            local_files_only=not Globals.internet_available,
        )

        self.logger.debug(f"Loading diffusers VAE from {name_or_path}")
        if using_fp16:
            vae_args.update(torch_dtype=torch.float16)
            fp_args_list = [{"revision": "fp16"}, {}]
        else:
            self.logger.debug("Using more accurate float32 precision")
            fp_args_list = [{}]

        vae = None
        deferred_error = None

        # A VAE may be in a subfolder of a model's repository.
        if "subfolder" in vae_config:
            vae_args["subfolder"] = vae_config["subfolder"]

        for fp_args in fp_args_list:
            # At some point we might need to be able to use different classes here? But for now I think
            # all Stable Diffusion VAE are AutoencoderKL.
            try:
                vae = AutoencoderKL.from_pretrained(name_or_path, **vae_args, **fp_args)
            except OSError as e:
                if str(e).startswith("fp16 is not a valid"):
                    pass
                else:
                    deferred_error = e
            if vae:
                break

        if not vae and deferred_error:
            self.logger.warning(f"Could not load VAE {name_or_path}: {str(deferred_error)}")

        return vae

    @classmethod
    def _delete_model_from_cache(cls,repo_id):
        cache_info = scan_cache_dir(global_cache_dir("hub"))
@@ -1326,8 +794,3 @@ class ModelManager(object):
            return path
        return Path(Globals.root, path).resolve()

    @staticmethod
    def _is_huggingface_hub_directory_present() -> bool:
        return (
            os.getenv("HF_HOME") is not None or os.getenv("XDG_CACHE_HOME") is not None
        )