add max_loaded_models parameter for model cache control

- ldm.generate.Generate() now takes an argument named `max_loaded_models`.
  This is an integer that limits the size of the model cache. When the
  cache reaches the limit, older models are purged to make room (see the
  usage sketch after this list).

- The CLI takes a corresponding --max_loaded_models argument, which
  defaults to 2. This keeps one model in GPU memory and one in CPU
  memory, so switching between the two is fast.

- To disable model caching entirely, pass --max_loaded_models=1.
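
A minimal usage sketch, illustrative only: it assumes an InvokeAI-style
checkout where ldm.generate is importable and models.yaml is configured;
the model name, prompt, and steps value are placeholders, and prompt2image()
is the existing text-to-image entry point rather than something added here.
Only max_loaded_models itself comes from this commit.

    # Illustrative sketch -- not part of this commit.
    # Assumes ldm.generate is importable and models.yaml is set up.
    from ldm.generate import Generate

    # Keep at most two models cached; the oldest model is purged when a
    # third is requested. Pass max_loaded_models=1 to disable caching.
    gr = Generate(
        model='stable-diffusion-1.5',   # placeholder model name
        max_loaded_models=2,
    )
    results = gr.prompt2image('a sunlit forest clearing', steps=30)
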
Lincoln Stein
2022-10-31 08:53:16 -04:00
parent 98f03053ba
commit 0c8f0e3386
4 changed files with 29 additions and 16 deletions

ldm/generate.py

@@ -57,7 +57,7 @@ torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 # this is fallback model in case no default is defined
-FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+FALLBACK_MODEL_NAME='stable-diffusion-1.5'
 """Simplified text to image API for stable diffusion/latent diffusion
@@ -146,6 +146,7 @@ class Generate:
 esrgan=None,
 free_gpu_mem=False,
 safety_checker:bool=False,
+max_loaded_models:int=3,
 # these are deprecated; if present they override values in the conf file
 weights = None,
 config = None,
@@ -177,6 +178,7 @@ class Generate:
 self.codeformer = codeformer
 self.esrgan = esrgan
 self.free_gpu_mem = free_gpu_mem
+self.max_loaded_models = max_loaded_models
 self.size_matters = True # used to warn once about large image sizes and VRAM
 self.txt2mask = None
 self.safety_checker = None
@@ -198,7 +200,7 @@
 self.precision = choose_precision(self.device)
 # model caching system for fast switching
-self.model_cache = ModelCache(mconfig,self.device,self.precision)
+self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
 self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 # for VRAM usage statistics
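
For reference, a minimal sketch of the bounded, least-recently-used eviction
behaviour the commit message describes. This is not the real ModelCache
implementation; the class and method names below are invented for
illustration.

    from collections import OrderedDict

    class BoundedModelCache:
        """Illustrative only -- not ldm's ModelCache.

        Holds at most max_loaded_models entries; when full, the least
        recently used model is purged before a new one is loaded.
        """

        def __init__(self, max_loaded_models: int = 2):
            self.max_loaded_models = max_loaded_models
            self._models = OrderedDict()   # model name -> loaded model object

        def get(self, name, loader):
            if name in self._models:
                self._models.move_to_end(name)   # mark as most recently used
                return self._models[name]
            while len(self._models) >= self.max_loaded_models:
                self._models.popitem(last=False)  # purge the oldest entry
            self._models[name] = loader(name)
            return self._models[name]

    # Example: with max_loaded_models=2, requesting a third distinct model
    # evicts whichever of the first two was used least recently.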