From 0c8f0e3386259afb8b95a5656bcc49ac74f5b309 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Mon, 31 Oct 2022 08:53:16 -0400
Subject: [PATCH] add max_loaded_models parameter for model cache control

- ldm.generate.Generate() now takes an argument named `max_loaded_models`.
  This is an integer that limits the model cache size. When the cache
  reaches the limit, it starts purging the oldest models from the cache.

- The CLI takes the argument --max_loaded_models, which defaults to 2.
  This keeps one model in the GPU and one in system RAM, allowing fast
  switching back and forth between the two.

- To disable model caching entirely, pass --max_loaded_models=1.

---
 ldm/generate.py           |  6 ++++--
 ldm/invoke/args.py        |  7 +++++++
 ldm/invoke/model_cache.py | 32 ++++++++++++++++----------------
 scripts/invoke.py         |  5 +++++
 4 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/ldm/generate.py b/ldm/generate.py
index 40a079dcef..56acd14738 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -57,7 +57,7 @@
 torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 
 # this is fallback model in case no default is defined
-FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+FALLBACK_MODEL_NAME='stable-diffusion-1.5'
 
 """Simplified text to image API for stable diffusion/latent diffusion
@@ -146,6 +146,7 @@ class Generate:
             esrgan=None,
             free_gpu_mem=False,
             safety_checker:bool=False,
+            max_loaded_models:int=3,
             # these are deprecated; if present they override values in the conf file
             weights = None,
             config = None,
@@ -177,6 +178,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
+        self.max_loaded_models = max_loaded_models
         self.size_matters = True # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -198,7 +200,7 @@ class Generate:
         self.precision = choose_precision(self.device)
 
         # model caching system for fast switching
-        self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
         self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 
         # for VRAM usage statistics
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index b509362644..ff590edc74 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -413,6 +413,13 @@ class Args(object):
             action='store_true',
             help='Deprecated way to set --precision=float32',
         )
+        model_group.add_argument(
+            '--max_loaded_models',
+            dest='max_loaded_models',
+            type=int,
+            default=2,
+            help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
+        )
         model_group.add_argument(
             '--free_gpu_mem',
             dest='free_gpu_mem',
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index ff72ce951f..092fc7a7a2 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -20,12 +20,10 @@ from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
 from ldm.util import instantiate_from_config
 
-GIGS=2**30
-AVG_MODEL_SIZE=2.1*GIGS
-DEFAULT_MIN_AVAIL=2*GIGS
+DEFAULT_MAX_MODELS=3
 
 class ModelCache(object):
-    def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
+    def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
         '''
         Initialize with the path to the models.yaml config file,
         the torch device type, and precision. The optional
@@ -38,7 +36,7 @@ class ModelCache(object):
         self.config = config
         self.precision = precision
         self.device = torch.device(device_type)
-        self.min_avail_mem = min_avail_mem
+        self.max_loaded_models = max_loaded_models
         self.models = {}
         self.stack = [] # this is an LRU FIFO
         self.current_model = None
@@ -63,8 +61,8 @@ class ModelCache(object):
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
             hash = self.models[model_name]['hash']
-        else:
-            self._check_memory()
+        else: # we're about to load a new model, so potentially unload the least recently used one
+            self._check_cache_size()
             try:
                 requested_model, width, height, hash = self._load_model(model_name)
                 self.models[model_name] = {}
@@ -178,14 +176,16 @@
             self._invalidate_cached_model(model_name)
         return True
 
-    def _check_memory(self):
-        avail_memory = psutil.virtual_memory()[1]
-        if AVG_MODEL_SIZE + self.min_avail_mem > avail_memory:
+    def _check_cache_size(self):
+        num_loaded_models = len(self.models)
+        if num_loaded_models >= self.max_loaded_models:
             least_recent_model = self._pop_oldest_model()
             if least_recent_model is not None:
+                print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
                 del self.models[least_recent_model]
                 gc.collect()
-
+        else:
+            print('>> Model will be cached in system RAM')
 
     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
@@ -261,7 +261,7 @@
     def unload_model(self, model_name:str):
         if model_name not in self.models:
             return
-        print(f'>> Caching model {model_name} in system RAM')
+        print(f'>> Unloading {model_name} from GPU')
         model = self.models[model_name]['model']
         self.models[model_name]['model'] = self._model_to_cpu(model)
         gc.collect()
@@ -322,8 +322,8 @@
-        to be the least recently accessed model. Do not pop the
-        last one, because it is in active use!
+        to be the least recently accessed model. When max_loaded_models
+        is 1, this may be the model in active use.
         '''
-        if len(self.stack) > 1:
-            return self.stack.pop(0)
+        if len(self.stack) > 0:
+            return self.stack.pop(0)
 
     def _push_newest_model(self,model_name:str):
         '''
diff --git a/scripts/invoke.py b/scripts/invoke.py
index 5094473d33..62e7b5dea1 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -38,6 +38,10 @@ def main():
     if args.weights:
         print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
         sys.exit(-1)
+    if args.max_loaded_models is not None:
+        if args.max_loaded_models <= 0:
+            print('--max_loaded_models must be >= 1; using 1')
+            args.max_loaded_models = 1
 
     print('* Initializing, be patient...')
     from ldm.generate import Generate
@@ -81,6 +85,7 @@ def main():
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
             safety_checker=opt.safety_checker,
+            max_loaded_models=opt.max_loaded_models,
         )
     except (FileNotFoundError, IOError, KeyError) as e:
         print(f'{e}. Aborting.')
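
For reviewers who want to sanity-check the purge order without loading real checkpoints, here is a minimal, self-contained Python sketch of the policy this patch implements. ToyCache and the model names are hypothetical stand-ins; the real logic lives in ModelCache._check_cache_size(), _pop_oldest_model() and _push_newest_model() in ldm/invoke/model_cache.py.

    # Toy illustration only -- not the actual ModelCache class.
    class ToyCache:
        def __init__(self, max_loaded_models=2):
            self.max_loaded_models = max_loaded_models
            self.models = {}   # model_name -> loaded weights (placeholder here)
            self.stack = []    # LRU order: oldest first, most recently used last

        def get_model(self, name):
            if name in self.models:
                self._push_newest(name)       # cache hit: mark as most recent
                return self.models[name]
            if len(self.models) >= self.max_loaded_models:
                oldest = self._pop_oldest()   # cache full: purge least recent
                if oldest is not None:
                    del self.models[oldest]
            self.models[name] = f'<weights of {name}>'  # stand-in for a real load
            self._push_newest(name)
            return self.models[name]

        def _pop_oldest(self):
            return self.stack.pop(0) if self.stack else None

        def _push_newest(self, name):
            if name in self.stack:
                self.stack.remove(name)
            self.stack.append(name)

    cache = ToyCache(max_loaded_models=2)
    cache.get_model('model-a')
    cache.get_model('model-b')
    cache.get_model('model-a')   # hit: no purge; 'model-a' becomes most recent
    cache.get_model('model-c')   # full: purges 'model-b', the least recently used
    print(cache.stack)           # ['model-a', 'model-c']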
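
And a hedged usage sketch of the new parameter through the Python API. It assumes only the Generate() signature added by this patch plus the txt2img convenience call shown in ldm/generate.py's "Simplified text to image API" docstring; a populated configs/models.yaml and downloaded weights are prerequisites, and the prompt is a placeholder.

    from ldm.generate import Generate

    # Keep at most two models resident: the active one in VRAM and one
    # cached in system RAM for fast switching.
    gr = Generate(max_loaded_models=2)
    outputs = gr.txt2img('a sunlit meadow, oil painting')

    # CLI equivalents:
    #   python scripts/invoke.py --max_loaded_models=2   # default: 1 in GPU + 1 in RAM
    #   python scripts/invoke.py --max_loaded_models=1   # disable caching entirely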