add max_loaded_models parameter for model cache control

- ldm.generate.Generate() now takes an argument named `max_loaded_models`.
  This is an integer that limits the model cache size. When the cache
  reaches the limit, it will start purging older models from the cache.

- The CLI takes an argument --max_loaded_models, defaulting to 2. This
  will keep one model in GPU and the other in CPU, so that it can switch
  back and forth between them quickly.

- To not cache models at all, pass --max_loaded_models=1
This commit is contained in:
Lincoln Stein
2022-10-31 08:53:16 -04:00
parent 98f03053ba
commit 0c8f0e3386
4 changed files with 29 additions and 16 deletions

View File

@ -38,6 +38,10 @@ def main():
if args.weights:
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1)
if args.max_loaded_models is not None:
if args.max_loaded_models <= 0:
print('--max_loaded_models must be >= 1; using 1')
args.max_loaded_models = 1
print('* Initializing, be patient...')
from ldm.generate import Generate
@ -81,6 +85,7 @@ def main():
esrgan=esrgan,
free_gpu_mem=opt.free_gpu_mem,
safety_checker=opt.safety_checker,
max_loaded_models=opt.max_loaded_models,
)
except (FileNotFoundError, IOError, KeyError) as e:
print(f'{e}. Aborting.')