Mirror of https://github.com/invoke-ai/InvokeAI
add max_loaded_models parameter for model cache control

- ldm.generate.Generate() now takes an argument named `max_loaded_models`. This integer limits the model cache size; when the cache reaches the limit, it starts purging the least recently used models (a minimal sketch of this policy appears below, after the commit metadata).
- The CLI takes an argument `--max_loaded_models`, defaulting to 2. This keeps one model in the GPU and one in CPU RAM, so you can switch back and forth between them quickly.
- To disable model caching entirely, pass `--max_loaded_models=1`.
This commit is contained in: parent 3cdfedc649, commit dc556cb1a7
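The commit message above describes a least-recently-used eviction policy for cached models. What follows is a minimal, self-contained sketch of that policy only, written against the behavior the message states; TinyModelCache, the placeholder model names, and the fake "load" step are inventions for this example and are not the actual InvokeAI ModelCache class.

# Illustrative sketch only -- NOT the InvokeAI ModelCache class. The class name,
# the placeholder model names, and the fake "load" step are invented for this
# example; the eviction policy mirrors the commit message above.
import gc

class TinyModelCache:
    def __init__(self, max_loaded_models: int = 2):
        self.max_loaded_models = max_loaded_models
        self.models = {}   # model_name -> loaded weights (stand-in objects here)
        self.stack = []    # LRU FIFO: oldest entry first, most recent last

    def get_model(self, model_name: str):
        if model_name in self.models:
            model = self.models[model_name]
        else:
            # About to load a new model: purge the least recently used one
            # if the cache has already reached its limit.
            if len(self.models) >= self.max_loaded_models:
                oldest = self.stack.pop(0)
                print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {oldest}')
                del self.models[oldest]
                gc.collect()
            model = f'<weights for {model_name}>'  # stand-in for the real load
            self.models[model_name] = model
        # Mark this model as the most recently used entry.
        if model_name in self.stack:
            self.stack.remove(model_name)
        self.stack.append(model_name)
        return model

cache = TinyModelCache(max_loaded_models=2)
cache.get_model('stable-diffusion-1.5')   # loads and caches
cache.get_model('hypothetical-model-a')   # loads and caches (cache now full)
cache.get_model('stable-diffusion-1.5')   # cache hit, nothing is purged
cache.get_model('hypothetical-model-b')   # purges 'hypothetical-model-a', the LRU entry

With max_loaded_models=1 the sketch purges the previous model before every new load, which matches the commit's note that `--max_loaded_models=1` effectively disables caching.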
@@ -57,7 +57,7 @@ torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)

 # this is fallback model in case no default is defined
-FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+FALLBACK_MODEL_NAME='stable-diffusion-1.5'

 """Simplified text to image API for stable diffusion/latent diffusion
@@ -146,6 +146,7 @@ class Generate:
             esrgan=None,
             free_gpu_mem=False,
             safety_checker:bool=False,
+            max_loaded_models:int=3,
             # these are deprecated; if present they override values in the conf file
             weights = None,
             config = None,
@@ -177,6 +178,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
+        self.max_loaded_models = max_loaded_models
         self.size_matters = True # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -198,7 +200,7 @@ class Generate:
         self.precision = choose_precision(self.device)

         # model caching system for fast switching
-        self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
         self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME

         # for VRAM usage statistics
@@ -413,6 +413,13 @@ class Args(object):
             action='store_true',
             help='Deprecated way to set --precision=float32',
         )
+        model_group.add_argument(
+            '--max_loaded_models',
+            dest='max_loaded_models',
+            type=int,
+            default=2,
+            help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
+        )
         model_group.add_argument(
             '--free_gpu_mem',
             dest='free_gpu_mem',
@@ -20,12 +20,10 @@ from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
 from ldm.util import instantiate_from_config

-GIGS=2**30
-AVG_MODEL_SIZE=2.1*GIGS
-DEFAULT_MIN_AVAIL=2*GIGS
+DEFAULT_MAX_MODELS=3

 class ModelCache(object):
-    def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
+    def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
         '''
         Initialize with the path to the models.yaml config file,
         the torch device type, and precision. The optional
@@ -38,7 +36,7 @@ class ModelCache(object):
         self.config = config
         self.precision = precision
         self.device = torch.device(device_type)
-        self.min_avail_mem = min_avail_mem
+        self.max_loaded_models = max_loaded_models
         self.models = {}
         self.stack = [] # this is an LRU FIFO
         self.current_model = None
@@ -63,8 +61,8 @@ class ModelCache(object):
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
             hash = self.models[model_name]['hash']
-        else:
-            self._check_memory()
+        else: # we're about to load a new model, so potentially unload the least recently used one
+            self._check_cache_size()
             try:
                 requested_model, width, height, hash = self._load_model(model_name)
                 self.models[model_name] = {}
@@ -178,14 +176,16 @@ class ModelCache(object):
         self._invalidate_cached_model(model_name)
         return True

-    def _check_memory(self):
-        avail_memory = psutil.virtual_memory()[1]
-        if AVG_MODEL_SIZE + self.min_avail_mem > avail_memory:
+    def _check_cache_size(self):
+        num_loaded_models = len(self.models)
+        if num_loaded_models >= self.max_loaded_models:
             least_recent_model = self._pop_oldest_model()
+            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
             if least_recent_model is not None:
                 del self.models[least_recent_model]
                 gc.collect()
         else:
             print(f'>> Model will be cached in CPU')

     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
@@ -261,7 +261,7 @@ class ModelCache(object):
     def unload_model(self, model_name:str):
         if model_name not in self.models:
             return
-        print(f'>> Caching model {model_name} in system RAM')
+        print(f'>> Unloading {model_name} from GPU')
         model = self.models[model_name]['model']
         self.models[model_name]['model'] = self._model_to_cpu(model)
         gc.collect()
@@ -322,8 +322,7 @@ class ModelCache(object):
        to be the least recently accessed model. Do not
        pop the last one, because it is in active use!
        '''
-        if len(self.stack) > 1:
-            return self.stack.pop(0)
+        return self.stack.pop(0)

     def _push_newest_model(self,model_name:str):
         '''
@@ -38,6 +38,10 @@ def main():
     if args.weights:
         print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
         sys.exit(-1)
+    if args.max_loaded_models is not None:
+        if args.max_loaded_models <= 0:
+            print('--max_loaded_models must be >= 1; using 1')
+            args.max_loaded_models = 1

     print('* Initializing, be patient...')
     from ldm.generate import Generate
@@ -81,6 +85,7 @@ def main():
                          esrgan=esrgan,
                          free_gpu_mem=opt.free_gpu_mem,
                          safety_checker=opt.safety_checker,
+                          max_loaded_models=opt.max_loaded_models,
                          )
     except (FileNotFoundError, IOError, KeyError) as e:
         print(f'{e}. Aborting.')