Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
cap model cache size using bytes, not # models
parent 647ffb2a0f
commit 667171ed90
@@ -59,7 +59,7 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
         config.conf,
         precision=dtype,
         device_type=device,
-        max_loaded_models=config.max_loaded_models,
+        max_cache_size=config.max_cache_size,
         # temporarily disabled until model manager stabilizes
         # embedding_path = Path(embedding_path),
         logger = logger,

@@ -502,11 +502,11 @@ class Args(object):
             help="Deprecated way to set --precision=float32",
         )
         model_group.add_argument(
-            "--max_loaded_models",
-            dest="max_loaded_models",
-            type=int,
-            default=2,
-            help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
+            "--max_cache_size",
+            dest="max_cache_size",
+            type=float,
+            default=6.0,
+            help="Maximum size of the model RAM cache (in GB). 6 GB is sufficient to keep 2-3 diffusers models in RAM simultaneously.",
         )
         model_group.add_argument(
             "--free_gpu_mem",

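For reference, a minimal standalone sketch of how the new flag parses, using plain argparse. The real option lives in the Args class above; the parser and group names here are illustrative only.

import argparse

parser = argparse.ArgumentParser()
model_group = parser.add_argument_group("model cache")
model_group.add_argument(
    "--max_cache_size",
    dest="max_cache_size",
    type=float,
    default=6.0,
    help="Maximum size of the model RAM cache (in GB).",
)

opts = parser.parse_args(["--max_cache_size", "7.5"])
print(opts.max_cache_size)  # 7.5
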
@@ -19,11 +19,13 @@ context. Use like this:
 import contextlib
 import gc
 import hashlib
 import logging
 import warnings
 from collections import Counter
 from enum import Enum
 from pathlib import Path
-from typing import Sequence, Union, Tuple, types
+from psutil import Process
+from typing import Dict, Sequence, Union, Tuple, types

 import torch
 import safetensors.torch

@@ -41,7 +43,12 @@ import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline

-MAX_MODELS = 4
+# Maximum size of the cache, in gigs
+# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
+DEFAULT_MAX_CACHE_SIZE = 6.0
+
+# actual size of a gig
+GIG = 1073741824

 # This is the mapping from the stable diffusion submodel dict key to the class
 class SDModelType(Enum):

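GIG is 2**30 bytes (one gibibyte), so the GB-denominated setting converts to a byte budget with a single multiplication. A minimal sketch of that conversion, using only the constants introduced above:

GIG = 1073741824              # 2**30 bytes
DEFAULT_MAX_CACHE_SIZE = 6.0  # GB

# the cache compares sizes in bytes, so the GB setting becomes a byte budget
maximum_size = DEFAULT_MAX_CACHE_SIZE * GIG
print(f"{int(maximum_size)} bytes = {maximum_size / GIG:.1f} GB")
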
@@ -65,6 +72,24 @@ class ModelStatus(Enum):
     in_ram='cached'
     in_vram='in gpu'
     active='locked in gpu'

+# This is used to guesstimate the size of a model before we load it.
+# After loading, we will know it exactly.
+# Sizes are in Gigs, estimated for float16; double for float32
+SIZE_GUESSTIMATE = {
+    SDModelType.diffusion_pipeline: 2.5,
+    SDModelType.diffusers: 2.5,
+    SDModelType.vae: 0.35,
+    SDModelType.text_encoder: 0.5,
+    SDModelType.tokenizer: 0.0001,
+    SDModelType.unet: 3.4,
+    SDModelType.scheduler: 0.0001,
+    SDModelType.safety_checker: 1.2,
+    SDModelType.feature_extractor: 0.0001,
+    SDModelType.lora: 0.1,
+    SDModelType.textual_inversion: 0.0001,
+    SDModelType.ckpt: 4.2,
+}
+
 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]

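The table values are fp16 estimates in GB; when a model has never been loaded before, they are doubled for float32 and scaled to bytes, as the updated _make_cache_room does further down. A standalone sketch of that arithmetic (estimated_bytes is a hypothetical helper, not part of the diff):

GIG = 1073741824

def estimated_bytes(guesstimate_gb: float, full_precision: bool) -> int:
    # table values assume float16; double them for float32
    multiplier = 2 if full_precision else 1
    return int(guesstimate_gb * GIG * multiplier)

print(estimated_bytes(2.5, False) / GIG)  # 2.5  (fp16 diffusion_pipeline)
print(estimated_bytes(2.5, True) / GIG)   # 5.0  (same model at float32)
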
@@ -90,7 +115,7 @@ class ModelLocker(object):
 class ModelCache(object):
     def __init__(
         self,
-        max_models: int=MAX_MODELS,
+        max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
         execution_device: torch.device=torch.device('cuda'),
         storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,

@@ -113,13 +138,15 @@ class ModelCache(object):
         self.lazy_offloading = lazy_offloading
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
-        self.max_models: int=max_models
+        self.current_cache_size: int=0
+        self.max_cache_size: int=max_cache_size
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
         self.logger = logger
         self.loaded_models: set = set()   # set of model keys loaded in GPU
         self.locked_models: Counter = Counter()   # set of model keys locked in GPU
+        self.model_sizes: Dict[str,int] = dict()

     def get_model(
         self,

@@ -172,21 +199,33 @@ class ModelCache(object):
                     model_type.value,
                     revision,
                     subfolder
                 )
             )

         if key in self.models: # cached - move to bottom of stack
             with contextlib.suppress(ValueError):
                 self.stack.remove(key)
             self.stack.append(key)
             model = self.models[key]

         else: # not cached -load
-            self._make_cache_room()
-            model = self._load_model_from_storage(
-                repo_id_or_path=repo_id_or_path,
-                model_class=model_type.value,
-                subfolder=subfolder,
-                revision=revision,
-                legacy_info=legacy_info,
-            )
+            self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}')
+
+            # this will remove older cached models until
+            # there is sufficient room to load the requested model
+            self._make_cache_room(key, model_type)
+
+            with MemoryUsage() as usage:
+                model = self._load_model_from_storage(
+                    repo_id_or_path=repo_id_or_path,
+                    model_class=model_type.value,
+                    subfolder=subfolder,
+                    revision=revision,
+                    legacy_info=legacy_info,
+                )
+            logger.debug(f'Actual memory used to load model: {(usage.mem_used/GIG):.2f} GB')
+            self.model_sizes[key] = usage.mem_used
+            self.current_cache_size += usage.mem_used

             if model_type==SDModelType.diffusion_pipeline and attach_model_part[0]:
                 self.attach_part(model,*attach_model_part)
             self.stack.append(key)    # add to LRU cache

@@ -200,11 +239,11 @@ class ModelCache(object):
     def uncache_model(self, key: str):
         '''Remove corresponding model from the cache'''
         if key is not None and key in self.models:
-            with contextlib.suppress(ValueError):
-                del self.models[key]
-                self.stack.remove(key)
-                self.loaded_models.remove(key)
+            with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
+                del self.models[key]
+                del self.locked_models[key]
+                self.loaded_models.remove(key)
+                self.stack.remove(key)

 class ModelLocker(object):
     def __init__(self, cache, key, model, gpu_load):

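A small standalone illustration of the combined suppression used above. Note that contextlib.suppress exits the with-block at the first suppressed exception, so statements after the failing one do not run:

import contextlib

d, lst = {}, []
with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
    del d["missing"]        # raises KeyError, which is suppressed...
    lst.remove("missing")   # ...so this line is never reached
print("still running")
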
@@ -304,9 +343,9 @@ class ModelCache(object):
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)

-    def cache_size(self)->int:
-        "Return the current number of models cached."
-        return len(self.models)
+    def cache_size(self)->float:
+        "Return the current size of the cache, in GB"
+        return self.current_cache_size / GIG

     @classmethod
     def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:

@@ -342,18 +381,29 @@ class ModelCache(object):
         return self.execution_device.type == 'cuda'

     def _print_cuda_stats(self):
-        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
+        ram = "%4.2fG" % (self.current_cache_size / GIG)
         loaded_models = len(self.loaded_models)
         locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
-        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
+        logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; locked_models/loaded_models = {locked_models}/{loaded_models}")

-    def _make_cache_room(self):
-        models_in_ram = len(self.models)
-        while models_in_ram >= self.max_models:
+    def _make_cache_room(self, key, model_type):
+        # calculate how much memory this model will require
+        multiplier = 2 if self.precision==torch.float32 else 1
+        bytes_needed = int(self.model_sizes.get(key,0) or SIZE_GUESSTIMATE[model_type]*GIG*multiplier)
+        maximum_size = self.max_cache_size * GIG   # stored in GB, convert to bytes
+        current_size = self.current_cache_size
+
+        adjective = 'guesstimated' if key not in self.model_sizes else 'known from previous load'
+        logger.debug(f'{(bytes_needed/GIG):.2f} GB needed to load this model ({adjective})')
+        while current_size+bytes_needed > maximum_size:
             if least_recently_used_key := self.stack.pop(0):
-                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
-                del self.models[least_recently_used_key]
-                models_in_ram = len(self.models)
+                model_size = self.model_sizes.get(least_recently_used_key,0)
+                logger.debug(f'Max cache size exceeded: cache_size={(current_size/GIG):.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
+                logger.debug(f'Unloading model {least_recently_used_key} to free {(model_size/GIG):.2f} GB')
+                self.uncache_model(least_recently_used_key)
+                current_size -= model_size
+        self.current_cache_size = current_size
         gc.collect()

     def _offload_unlocked_models(self):

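As a standalone model of the eviction policy above: a toy byte-capped LRU that evicts from the front of a stack until the requested bytes fit. Class and method names are illustrative, not the repo's:

GIG = 1073741824

class TinyByteLRU:
    def __init__(self, max_bytes: int):
        self.max_bytes = max_bytes
        self.current = 0
        self.items = {}      # key -> object
        self.sizes = {}      # key -> bytes
        self.stack = []      # least recently used first

    def make_room(self, bytes_needed: int):
        # evict oldest entries until the new item fits under the byte cap
        while self.current + bytes_needed > self.max_bytes and self.stack:
            lru = self.stack.pop(0)
            self.current -= self.sizes.pop(lru, 0)
            self.items.pop(lru, None)

    def put(self, key, obj, size: int):
        self.make_room(size)
        self.items[key] = obj
        self.sizes[key] = size
        self.current += size
        self.stack.append(key)

cache = TinyByteLRU(max_bytes=int(6.0 * GIG))
cache.put("model-a", object(), int(2.5 * GIG))
cache.put("model-b", object(), int(3.0 * GIG))
cache.put("model-c", object(), int(2.5 * GIG))   # evicts model-a
print(list(cache.items))  # ['model-b', 'model-c']
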
@@ -393,8 +443,8 @@ class ModelCache(object):
             revision,
             model_class,
         )
-        if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
-            model.enable_offload_submodels(self.execution_device)
+        if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
+            model.enable_offload_submodels(self.execution_device)
         return model

     def _load_diffusers_from_storage(

@@ -411,7 +461,6 @@ class ModelCache(object):
        :param revision: model revision
        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
        '''
-        self.logger.info(f'Loading model {repo_id_or_path}')
        revisions = [revision] if revision \
            else ['fp16','main'] if self.precision==torch.float16 \
            else ['main']

@@ -529,3 +578,15 @@ class SilenceWarnings(object):
         transformers_logging.set_verbosity(self.transformers_verbosity)
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
+
+class MemoryUsage(object):
+    def __init__(self):
+        self.vms = None
+        self.mem_used = 0
+
+    def __enter__(self):
+        self.vms = Process().memory_info().vms
+        return self
+
+    def __exit__(self, *args):
+        self.mem_used = Process().memory_info().vms - self.vms

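A standalone illustration of the MemoryUsage pattern: sample the process's virtual memory size before and after a block and record the difference. Requires psutil; the reported figure is approximate and platform-dependent.

from psutil import Process

class MemUsageSketch:          # same shape as MemoryUsage above
    def __enter__(self):
        self.vms = Process().memory_info().vms
        return self

    def __exit__(self, *args):
        self.mem_used = Process().memory_info().vms - self.vms

with MemUsageSketch() as usage:
    buf = bytearray(200 * 1024 * 1024)   # allocate roughly 200 MB

print(f"{usage.mem_used / 1073741824:.2f} GB")
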
@@ -141,7 +141,7 @@ class SDLegacyType(Enum):
     V2_v = auto()
     UNKNOWN = auto()

-DEFAULT_MAX_MODELS = 2
+MAX_CACHE_SIZE = 6.0  # GB

 class ModelManager(object):
     """

@@ -155,7 +155,7 @@ class ModelManager(object):
         config_path: Path,
         device_type: torch.device = CUDA_DEVICE,
         precision: torch.dtype = torch.float16,
-        max_loaded_models=DEFAULT_MAX_MODELS,
+        max_cache_size=MAX_CACHE_SIZE,
         sequential_offload=False,
         logger: types.ModuleType = logger,
     ):

@@ -168,7 +168,7 @@ class ModelManager(object):
         self.config_path = config_path
         self.config = OmegaConf.load(self.config_path)
         self.cache = ModelCache(
-            max_models=max_loaded_models,
+            max_cache_size=max_cache_size,
             execution_device = device_type,
             precision = precision,
             sequential_offload = sequential_offload,