add max_loaded_models parameter for model cache control

- ldm.generate.Generate() now takes an argument named `max_loaded_models`.
  This integer limits the size of the model cache. When the cache reaches
  the limit, the least recently used models are purged from it.

- The CLI takes an argument --max_loaded_models, which defaults to 2. This
  keeps one model in the GPU and one cached in CPU RAM, so the two can be
  switched back and forth quickly.

- To keep only the model currently in use, with no extra caching, pass
  --max_loaded_models=1 (see the usage sketch below).
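
For illustration, a minimal Python sketch of the new parameter in use. This is not part of the commit: the model names and prompt are placeholders that assume matching entries in models.yaml, and prompt2image() is the existing text-to-image call.

```python
from ldm.generate import Generate

# max_loaded_models=2 (the CLI default) keeps the active model on the GPU
# plus one more cached in CPU RAM; requesting a third model evicts the
# least recently used one.
gr = Generate(model='stable-diffusion-1.5', max_loaded_models=2)

# Ordinary generation is unchanged; the cache only matters when switching models.
results = gr.prompt2image(prompt='a sunlit forest clearing', steps=30)
```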
Lincoln Stein 2022-10-31 08:53:16 -04:00
parent 3cdfedc649
commit dc556cb1a7
4 changed files with 29 additions and 16 deletions


@@ -57,7 +57,7 @@ torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
# this is fallback model in case no default is defined
FALLBACK_MODEL_NAME='stable-diffusion-1.4'
FALLBACK_MODEL_NAME='stable-diffusion-1.5'
"""Simplified text to image API for stable diffusion/latent diffusion
@@ -146,6 +146,7 @@ class Generate:
esrgan=None,
free_gpu_mem=False,
safety_checker:bool=False,
max_loaded_models:int=3,
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
@@ -177,6 +178,7 @@ class Generate:
self.codeformer = codeformer
self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem
self.max_loaded_models = max_loaded_models
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
@@ -198,7 +200,7 @@ class Generate:
self.precision = choose_precision(self.device)
# model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision)
self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
# for VRAM usage statistics


@@ -413,6 +413,13 @@ class Args(object):
action='store_true',
help='Deprecated way to set --precision=float32',
)
model_group.add_argument(
'--max_loaded_models',
dest='max_loaded_models',
type=int,
default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)
model_group.add_argument(
'--free_gpu_mem',
dest='free_gpu_mem',


@@ -20,12 +20,10 @@ from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config
GIGS=2**30
AVG_MODEL_SIZE=2.1*GIGS
DEFAULT_MIN_AVAIL=2*GIGS
DEFAULT_MAX_MODELS=3
class ModelCache(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
'''
Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional
@@ -38,7 +36,7 @@ class ModelCache(object):
self.config = config
self.precision = precision
self.device = torch.device(device_type)
self.min_avail_mem = min_avail_mem
self.max_loaded_models = max_loaded_models
self.models = {}
self.stack = [] # this is an LRU FIFO
self.current_model = None
@@ -63,8 +61,8 @@ class ModelCache(object):
width = self.models[model_name]['width']
height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else:
self._check_memory()
else: # we're about to load a new model, so potentially unload the least recently used one
self._check_cache_size()
try:
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {}
@@ -178,14 +176,16 @@ class ModelCache(object):
self._invalidate_cached_model(model_name)
return True
def _check_memory(self):
avail_memory = psutil.virtual_memory()[1]
if AVG_MODEL_SIZE + self.min_avail_mem > avail_memory:
def _check_cache_size(self):
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
if least_recent_model is not None:
del self.models[least_recent_model]
gc.collect()
else:
print(f'>> Model will be cached in CPU')
def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time"""
@@ -261,7 +261,7 @@ class ModelCache(object):
def unload_model(self, model_name:str):
if model_name not in self.models:
return
print(f'>> Caching model {model_name} in system RAM')
print(f'>> Unloading {model_name} from GPU')
model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect()
@@ -322,8 +322,7 @@ class ModelCache(object):
to be the least recently accessed model. Do not
pop the last one, because it is in active use!
'''
if len(self.stack) > 1:
return self.stack.pop(0)
return self.stack.pop(0)
def _push_newest_model(self,model_name:str):
'''

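To show the eviction policy in isolation, here is a self-contained sketch that mimics the count-based purge above. TinyModelCache and its fake weight strings are invented for this illustration; the real ModelCache additionally moves evicted models between GPU and CPU RAM and tracks width/height/hash metadata.

```python
class TinyModelCache:
    '''Toy stand-in for ModelCache: count-limited cache with LRU eviction.'''

    def __init__(self, max_loaded_models=2):
        self.max_loaded_models = max_loaded_models
        self.models = {}   # model_name -> loaded object (fake strings here)
        self.stack = []    # LRU order: oldest first, newest last

    def get_model(self, model_name):
        if model_name in self.models:
            self._push_newest_model(model_name)   # refresh recency
            return self.models[model_name]
        # about to load a new model, so potentially unload the least recently used one
        self._check_cache_size()
        self.models[model_name] = f'<weights for {model_name}>'
        self._push_newest_model(model_name)
        return self.models[model_name]

    def _check_cache_size(self):
        if len(self.models) >= self.max_loaded_models:
            least_recent_model = self._pop_oldest_model()
            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
            del self.models[least_recent_model]

    def _pop_oldest_model(self):
        return self.stack.pop(0)

    def _push_newest_model(self, model_name):
        if model_name in self.stack:
            self.stack.remove(model_name)
        self.stack.append(model_name)


cache = TinyModelCache(max_loaded_models=2)
cache.get_model('stable-diffusion-1.5')
cache.get_model('stable-diffusion-1.4')
cache.get_model('some-custom-model')   # purges stable-diffusion-1.5
```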

@@ -38,6 +38,10 @@ def main():
if args.weights:
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1)
if args.max_loaded_models is not None:
if args.max_loaded_models <= 0:
print('--max_loaded_models must be >= 1; using 1')
args.max_loaded_models = 1
print('* Initializing, be patient...')
from ldm.generate import Generate
@@ -81,6 +85,7 @@ def main():
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
safety_checker=opt.safety_checker,
max_loaded_models=opt.max_loaded_models,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')
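
Putting the pieces together, a rough end-to-end sketch of how the flag travels from the command line into the cache. The argparse scaffolding here is a simplified stand-in for the Args class, and it assumes an importable ldm.generate as in the diff:

```python
import argparse

from ldm.generate import Generate

parser = argparse.ArgumentParser()
parser.add_argument(
    '--max_loaded_models',
    dest='max_loaded_models',
    type=int,
    default=2,
    help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)
args = parser.parse_args()

# Same guard as main(): values below 1 make no sense, so clamp to 1,
# which keeps only the model currently in use.
if args.max_loaded_models is not None and args.max_loaded_models <= 0:
    print('--max_loaded_models must be >= 1; using 1')
    args.max_loaded_models = 1

# Generate forwards the value to ModelCache(..., max_loaded_models=...),
# which enforces the limit in _check_cache_size().
gen = Generate(max_loaded_models=args.max_loaded_models)
```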