Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Commit 667171ed90 (parent 647ffb2a0f)

cap model cache size using bytes, not # models

The model manager's RAM cache is now capped by total size rather than by a fixed number of models: the --max_loaded_models option (int, default 2) becomes --max_cache_size (float, in GB, default 6.0), ModelCache measures each model's actual memory footprint at load time, estimates not-yet-loaded models from a per-type size table, and evicts least-recently-used models until the requested model fits under the cap.
@@ -59,7 +59,7 @@ def get_model_manager(config: Args, logger: types.ModuleType) -> ModelManager:
         config.conf,
         precision=dtype,
         device_type=device,
-        max_loaded_models=config.max_loaded_models,
+        max_cache_size=config.max_cache_size,
         # temporarily disabled until model manager stabilizes
         # embedding_path = Path(embedding_path),
         logger = logger,
@@ -502,11 +502,11 @@ class Args(object):
             help="Deprecated way to set --precision=float32",
         )
         model_group.add_argument(
-            "--max_loaded_models",
-            dest="max_loaded_models",
-            type=int,
-            default=2,
-            help="Maximum number of models to keep in memory for fast switching, including the one in GPU",
+            "--max_cache_size",
+            dest="max_cache_size",
+            type=float,
+            default=6.0,
+            help="Maximum size of the model RAM cache (in GB). 6 GB is sufficient to keep 2-3 diffusers models in RAM simultaneously.",
         )
         model_group.add_argument(
             "--free_gpu_mem",
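For illustration, a minimal standalone sketch of the renamed flag's behaviour (plain argparse, not InvokeAI's actual Args machinery): the option now parses as a float number of gigabytes rather than an integer model count.

import argparse

# Standalone sketch of the new flag in isolation; InvokeAI's real parser has many more options.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_cache_size",
    dest="max_cache_size",
    type=float,
    default=6.0,
    help="Maximum size of the model RAM cache (in GB).",
)

opts = parser.parse_args(["--max_cache_size", "7.5"])
print(opts.max_cache_size)   # 7.5 (GB), compared against measured model sizes at load time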
@@ -19,11 +19,13 @@ context. Use like this:
 import contextlib
 import gc
 import hashlib
+import logging
 import warnings
 from collections import Counter
 from enum import Enum
 from pathlib import Path
-from typing import Sequence, Union, Tuple, types
+from psutil import Process
+from typing import Dict, Sequence, Union, Tuple, types
 
 import torch
 import safetensors.torch
@@ -41,7 +43,12 @@ import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
 
-MAX_MODELS = 4
+# Maximum size of the cache, in gigs
+# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
+DEFAULT_MAX_CACHE_SIZE = 6.0
+
+# actual size of a gig
+GIG = 1073741824
 
 # This is the mapping from the stable diffusion submodel dict key to the class
 class SDModelType(Enum):
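A quick sanity check of the unit conversion the cache relies on: GIG is 2**30 bytes, and the user-facing limit is configured in GB but compared in bytes. The values below are taken directly from the diff.

GIG = 1073741824                       # 2**30 bytes, i.e. one gibibyte
DEFAULT_MAX_CACHE_SIZE = 6.0           # configured in GB

maximum_size = DEFAULT_MAX_CACHE_SIZE * GIG
print(int(maximum_size))               # 6442450944 bytes
print(f"{maximum_size / GIG:.1f} GB")  # 6.0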
@@ -65,6 +72,24 @@ class ModelStatus(Enum):
     in_ram='cached'
     in_vram='in gpu'
     active='locked in gpu'
 
+# This is used to guesstimate the size of a model before we load it.
+# After loading, we will know it exactly.
+# Sizes are in Gigs, estimated for float16; double for float32
+SIZE_GUESSTIMATE = {
+    SDModelType.diffusion_pipeline: 2.5,
+    SDModelType.diffusers: 2.5,
+    SDModelType.vae: 0.35,
+    SDModelType.text_encoder: 0.5,
+    SDModelType.tokenizer: 0.0001,
+    SDModelType.unet: 3.4,
+    SDModelType.scheduler: 0.0001,
+    SDModelType.safety_checker: 1.2,
+    SDModelType.feature_extractor: 0.0001,
+    SDModelType.lora: 0.1,
+    SDModelType.textual_inversion: 0.0001,
+    SDModelType.ckpt: 4.2,
+}
+
 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]
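The table above is only consulted the first time a given model is loaded; afterwards the measured size is used. A small sketch of the pre-load estimate, mirroring the arithmetic _make_cache_room performs further down (string keys stand in for the SDModelType enum members):

import torch

GIG = 1073741824
# Subset of SIZE_GUESSTIMATE, with string keys standing in for SDModelType members.
SIZE_GUESSTIMATE_GB = {"diffusion_pipeline": 2.5, "vae": 0.35, "unet": 3.4}

def estimate_bytes(model_type: str, precision: torch.dtype) -> int:
    # Table values are for float16; double them for float32, as the cache does.
    multiplier = 2 if precision == torch.float32 else 1
    return int(SIZE_GUESSTIMATE_GB[model_type] * GIG * multiplier)

print(estimate_bytes("diffusion_pipeline", torch.float16) / GIG)  # 2.5
print(estimate_bytes("diffusion_pipeline", torch.float32) / GIG)  # 5.0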
@@ -90,7 +115,7 @@ class ModelLocker(object):
 class ModelCache(object):
     def __init__(
         self,
-        max_models: int=MAX_MODELS,
+        max_cache_size: float=DEFAULT_MAX_CACHE_SIZE,
         execution_device: torch.device=torch.device('cuda'),
         storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,
@@ -113,13 +138,15 @@ class ModelCache(object):
         self.lazy_offloading = lazy_offloading
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
-        self.max_models: int=max_models
+        self.current_cache_size: int=0
+        self.max_cache_size: int=max_cache_size
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
         self.logger = logger
         self.loaded_models: set = set()   # set of model keys loaded in GPU
         self.locked_models: Counter = Counter()   # set of model keys locked in GPU
+        self.model_sizes: Dict[str,int] = dict()
 
     def get_model(
         self,
@@ -172,21 +199,33 @@ class ModelCache(object):
             model_type.value,
             revision,
             subfolder
         )
 
         if key in self.models: # cached - move to bottom of stack
             with contextlib.suppress(ValueError):
                 self.stack.remove(key)
             self.stack.append(key)
             model = self.models[key]
 
         else:  # not cached -load
-            self._make_cache_room()
-            model = self._load_model_from_storage(
-                repo_id_or_path=repo_id_or_path,
-                model_class=model_type.value,
-                subfolder=subfolder,
-                revision=revision,
-                legacy_info=legacy_info,
-            )
+            self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}')
+
+            # this will remove older cached models until
+            # there is sufficient room to load the requested model
+            self._make_cache_room(key, model_type)
+
+            with MemoryUsage() as usage:
+                model = self._load_model_from_storage(
+                    repo_id_or_path=repo_id_or_path,
+                    model_class=model_type.value,
+                    subfolder=subfolder,
+                    revision=revision,
+                    legacy_info=legacy_info,
+                )
+            logger.debug(f'Actual memory used to load model: {(usage.mem_used/GIG):.2f} GB')
+            self.model_sizes[key] = usage.mem_used
+            self.current_cache_size += usage.mem_used
 
         if model_type==SDModelType.diffusion_pipeline and attach_model_part[0]:
             self.attach_part(model,*attach_model_part)
         self.stack.append(key)  # add to LRU cache
@@ -200,11 +239,11 @@ class ModelCache(object):
     def uncache_model(self, key: str):
         '''Remove corresponding model from the cache'''
         if key is not None and key in self.models:
-            with contextlib.suppress(ValueError):
+            with contextlib.suppress(ValueError), contextlib.suppress(KeyError):
                 del self.models[key]
                 del self.locked_models[key]
-                self.stack.remove(key)
                 self.loaded_models.remove(key)
+                self.stack.remove(key)
 
 class ModelLocker(object):
     def __init__(self, cache, key, model, gpu_load):
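A note on the stacked suppress managers above: contextlib.suppress accepts multiple exception types, so the stacked form behaves like a single suppress over both ValueError and KeyError (the KeyError case presumably covers a key missing from the models dict or the locked_models Counter). In either form, the first suppressed exception skips the remaining statements in the block, as this small illustration shows:

import contextlib

models = {}
loaded = set()

# Equivalent to the stacked form in the diff: either exception type is swallowed.
with contextlib.suppress(ValueError, KeyError):
    del models["missing-key"]      # raises KeyError, which is suppressed...
    loaded.remove("missing-key")   # ...so this line is never reached

print(loaded)   # set() -- untouched because the block exited early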
@@ -304,9 +343,9 @@ class ModelCache(object):
         else:
             return self._hf_commit_hash(repo_id_or_path,revision)
 
-    def cache_size(self)->int:
-        "Return the current number of models cached."
-        return len(self.models)
+    def cache_size(self)->float:
+        "Return the current size of the cache, in GB"
+        return self.current_cache_size / GIG
 
     @classmethod
     def is_legacy_ckpt(cls, repo_id_or_path: Union[str,Path])->bool:
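cache_size() therefore now reports gigabytes of measured model bytes rather than a model count. A quick worked example of the conversion (the model sizes are illustrative, taken from the guesstimate table):

GIG = 1073741824

# e.g. a unet (~3.4 GB) plus a VAE (~0.35 GB) resident in the RAM cache
current_cache_size = int(3.4 * GIG) + int(0.35 * GIG)
print(f"{current_cache_size / GIG:.2f} GB")   # 3.75 GB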
@@ -342,18 +381,29 @@ class ModelCache(object):
         return self.execution_device.type == 'cuda'
 
     def _print_cuda_stats(self):
-        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
+        ram = "%4.2fG" % (self.current_cache_size / GIG)
         loaded_models = len(self.loaded_models)
         locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
-        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
+        logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
 
-    def _make_cache_room(self):
-        models_in_ram = len(self.models)
-        while models_in_ram >= self.max_models:
+    def _make_cache_room(self, key, model_type):
+        # calculate how much memory this model will require
+        multiplier = 2 if self.precision==torch.float32 else 1
+        bytes_needed = int(self.model_sizes.get(key,0) or SIZE_GUESSTIMATE[model_type]*GIG*multiplier)
+        maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
+        current_size = self.current_cache_size
+
+        adjective = 'guesstimated' if key not in self.model_sizes else 'known from previous load'
+        logger.debug(f'{(bytes_needed/GIG):.2f} GB needed to load this model ({adjective})')
+        while current_size+bytes_needed > maximum_size:
             if least_recently_used_key := self.stack.pop(0):
-                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
-                del self.models[least_recently_used_key]
-                models_in_ram = len(self.models)
+                model_size = self.model_sizes.get(least_recently_used_key,0)
+                logger.debug(f'Max cache size exceeded: cache_size={(current_size/GIG):.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
+                logger.debug(f'Unloading model {least_recently_used_key} to free {(model_size/GIG):.2f} GB')
+                self.uncache_model(least_recently_used_key)
+                current_size -= model_size
+        self.current_cache_size = current_size
         gc.collect()
 
     def _offload_unlocked_models(self):
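The eviction loop above is the core of the commit: instead of dropping models until a fixed count is reached, it drops least-recently-used models until the incoming model's byte estimate fits under the cap. A minimal, self-contained sketch of that policy, using plain lists and dicts in place of the real ModelCache state:

GIG = 1073741824

def make_room(stack, model_sizes, current_size, bytes_needed, max_cache_size_gb):
    """Evict least-recently-used keys (front of `stack`) until the new model fits."""
    maximum_size = max_cache_size_gb * GIG
    while current_size + bytes_needed > maximum_size and stack:
        lru_key = stack.pop(0)                        # oldest key first
        current_size -= model_sizes.pop(lru_key, 0)   # reclaim its measured bytes
    return current_size

# Example: a 6 GB cap holding 3.4 GB + 2.5 GB must evict the oldest model
# before another 2.5 GB model can be admitted.
stack = ["unet-A", "pipeline-B"]
sizes = {"unet-A": int(3.4 * GIG), "pipeline-B": int(2.5 * GIG)}
remaining = make_room(stack, sizes, sum(sizes.values()), int(2.5 * GIG), 6.0)
print(stack, f"{remaining / GIG:.1f} GB still cached")   # ['pipeline-B'] 2.5 GB still cached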
@@ -393,8 +443,8 @@ class ModelCache(object):
             revision,
             model_class,
         )
         if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
             model.enable_offload_submodels(self.execution_device)
         return model
 
     def _load_diffusers_from_storage(
@@ -411,7 +461,6 @@ class ModelCache(object):
         :param revision: model revision
         :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
         '''
-        self.logger.info(f'Loading model {repo_id_or_path}')
         revisions = [revision] if revision \
             else ['fp16','main'] if self.precision==torch.float16 \
             else ['main']
@@ -529,3 +578,15 @@ class SilenceWarnings(object):
         transformers_logging.set_verbosity(self.transformers_verbosity)
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
+
+class MemoryUsage(object):
+    def __init__(self):
+        self.vms = None
+        self.mem_used = 0
+
+    def __enter__(self):
+        self.vms = Process().memory_info().vms
+        return self
+
+    def __exit__(self, *args):
+        self.mem_used = Process().memory_info().vms - self.vms
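MemoryUsage records the growth of the process's virtual memory size (psutil's vms) across the with-block, which serves as a proxy for how much RAM a model load consumed; note that vms measures virtual address space and so can differ from resident memory. A direct, runnable illustration of the underlying psutil call (the bytearray is a stand-in for a model load):

from psutil import Process

GIG = 1073741824

before = Process().memory_info().vms          # virtual memory size before the "load"
data = bytearray(256 * 1024 * 1024)           # stand-in for loading a model into RAM
used = Process().memory_info().vms - before   # what MemoryUsage.mem_used would report
print(f"measured ~{used / GIG:.2f} GB of additional virtual memory")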
@@ -141,7 +141,7 @@ class SDLegacyType(Enum):
     V2_v = auto()
     UNKNOWN = auto()
 
-DEFAULT_MAX_MODELS = 2
+MAX_CACHE_SIZE = 6.0  # GB
 
 class ModelManager(object):
     """
@@ -155,7 +155,7 @@ class ModelManager(object):
         config_path: Path,
         device_type: torch.device = CUDA_DEVICE,
         precision: torch.dtype = torch.float16,
-        max_loaded_models=DEFAULT_MAX_MODELS,
+        max_cache_size=MAX_CACHE_SIZE,
         sequential_offload=False,
         logger: types.ModuleType = logger,
     ):
@@ -168,7 +168,7 @@ class ModelManager(object):
         self.config_path = config_path
         self.config = OmegaConf.load(self.config_path)
         self.cache = ModelCache(
-            max_models=max_loaded_models,
+            max_cache_size=max_cache_size,
             execution_device = device_type,
             precision = precision,
             sequential_offload = sequential_offload,
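For a caller, the only visible change is the renamed keyword argument. A hedged construction sketch follows; the import path is an assumption about the package layout at this commit, the config path is a placeholder, and the remaining parameters are assumed to keep the defaults shown in the diff.

import torch
from pathlib import Path

# Assumed import path; adjust to wherever ModelManager lives in this revision.
from invokeai.backend.model_management.model_manager import ModelManager

manager = ModelManager(
    config_path=Path("configs/models.yaml"),   # placeholder path
    precision=torch.float16,
    max_cache_size=8.0,                        # GB; replaces the old max_loaded_models=N
)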