add max_load_models parameter for model cache control

- ldm.generate.Generator() now takes an argument named `max_load_models`.
  This is an integer that limits the model cache size. When the cache
  reaches the limit, it will start purging older models from cache.

- CLI takes an argument --max_load_models, default to 2. This will keep
  one model in GPU and the other in CPU and switch back and forth
  quickly.

- To not cache models at all, pass --max_load_models=1
This commit is contained in:
Lincoln Stein 2022-10-31 08:53:16 -04:00
parent 3cdfedc649
commit dc556cb1a7
4 changed files with 29 additions and 16 deletions

View File

@ -57,7 +57,7 @@ torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial) torch.multinomial = fix_func(torch.multinomial)
# this is fallback model in case no default is defined # this is fallback model in case no default is defined
FALLBACK_MODEL_NAME='stable-diffusion-1.4' FALLBACK_MODEL_NAME='stable-diffusion-1.5'
"""Simplified text to image API for stable diffusion/latent diffusion """Simplified text to image API for stable diffusion/latent diffusion
@ -146,6 +146,7 @@ class Generate:
esrgan=None, esrgan=None,
free_gpu_mem=False, free_gpu_mem=False,
safety_checker:bool=False, safety_checker:bool=False,
max_loaded_models:int=3,
# these are deprecated; if present they override values in the conf file # these are deprecated; if present they override values in the conf file
weights = None, weights = None,
config = None, config = None,
@ -177,6 +178,7 @@ class Generate:
self.codeformer = codeformer self.codeformer = codeformer
self.esrgan = esrgan self.esrgan = esrgan
self.free_gpu_mem = free_gpu_mem self.free_gpu_mem = free_gpu_mem
self.max_loaded_models = max_loaded_models,
self.size_matters = True # used to warn once about large image sizes and VRAM self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None self.txt2mask = None
self.safety_checker = None self.safety_checker = None
@ -198,7 +200,7 @@ class Generate:
self.precision = choose_precision(self.device) self.precision = choose_precision(self.device)
# model caching system for fast switching # model caching system for fast switching
self.model_cache = ModelCache(mconfig,self.device,self.precision) self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
# for VRAM usage statistics # for VRAM usage statistics

View File

@ -413,6 +413,13 @@ class Args(object):
action='store_true', action='store_true',
help='Deprecated way to set --precision=float32', help='Deprecated way to set --precision=float32',
) )
model_group.add_argument(
'--max_loaded_models',
dest='max_loaded_models',
type=int,
default=2,
help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)
model_group.add_argument( model_group.add_argument(
'--free_gpu_mem', '--free_gpu_mem',
dest='free_gpu_mem', dest='free_gpu_mem',

View File

@ -20,12 +20,10 @@ from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
GIGS=2**30 DEFAULT_MAX_MODELS=3
AVG_MODEL_SIZE=2.1*GIGS
DEFAULT_MIN_AVAIL=2*GIGS
class ModelCache(object): class ModelCache(object):
def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL): def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
''' '''
Initialize with the path to the models.yaml config file, Initialize with the path to the models.yaml config file,
the torch device type, and precision. The optional the torch device type, and precision. The optional
@ -38,7 +36,7 @@ class ModelCache(object):
self.config = config self.config = config
self.precision = precision self.precision = precision
self.device = torch.device(device_type) self.device = torch.device(device_type)
self.min_avail_mem = min_avail_mem self.max_loaded_models = max_loaded_models
self.models = {} self.models = {}
self.stack = [] # this is an LRU FIFO self.stack = [] # this is an LRU FIFO
self.current_model = None self.current_model = None
@ -63,8 +61,8 @@ class ModelCache(object):
width = self.models[model_name]['width'] width = self.models[model_name]['width']
height = self.models[model_name]['height'] height = self.models[model_name]['height']
hash = self.models[model_name]['hash'] hash = self.models[model_name]['hash']
else: else: # we're about to load a new model, so potentially unload the least recently used one
self._check_memory() self._check_cache_size()
try: try:
requested_model, width, height, hash = self._load_model(model_name) requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {} self.models[model_name] = {}
@ -178,14 +176,16 @@ class ModelCache(object):
self._invalidate_cached_model(model_name) self._invalidate_cached_model(model_name)
return True return True
def _check_memory(self): def _check_cache_size(self):
avail_memory = psutil.virtual_memory()[1] num_loaded_models = len(self.models)
if AVG_MODEL_SIZE + self.min_avail_mem > avail_memory: if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model() least_recent_model = self._pop_oldest_model()
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
if least_recent_model is not None: if least_recent_model is not None:
del self.models[least_recent_model] del self.models[least_recent_model]
gc.collect() gc.collect()
else:
print(f'>> Model will be cached in CPU')
def _load_model(self, model_name:str): def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time""" """Load and initialize the model from configuration variables passed at object creation time"""
@ -261,7 +261,7 @@ class ModelCache(object):
def unload_model(self, model_name:str): def unload_model(self, model_name:str):
if model_name not in self.models: if model_name not in self.models:
return return
print(f'>> Caching model {model_name} in system RAM') print(f'>> Unloading {model_name} from GPU')
model = self.models[model_name]['model'] model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model) self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect() gc.collect()
@ -322,7 +322,6 @@ class ModelCache(object):
to be the least recently accessed model. Do not to be the least recently accessed model. Do not
pop the last one, because it is in active use! pop the last one, because it is in active use!
''' '''
if len(self.stack) > 1:
return self.stack.pop(0) return self.stack.pop(0)
def _push_newest_model(self,model_name:str): def _push_newest_model(self,model_name:str):

View File

@ -38,6 +38,10 @@ def main():
if args.weights: if args.weights:
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.') print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1) sys.exit(-1)
if args.max_loaded_models is not None:
if args.max_loaded_models <= 0:
print('--max_loaded_models must be >= 1; using 1')
args.max_loaded_models = 1
print('* Initializing, be patient...') print('* Initializing, be patient...')
from ldm.generate import Generate from ldm.generate import Generate
@ -81,6 +85,7 @@ def main():
esrgan=esrgan, esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem, free_gpu_mem=opt.free_gpu_mem,
safety_checker=opt.safety_checker, safety_checker=opt.safety_checker,
max_loaded_models=opt.max_loaded_models,
) )
except (FileNotFoundError, IOError, KeyError) as e: except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.') print(f'{e}. Aborting.')