mirror of https://github.com/invoke-ai/InvokeAI
add max_loaded_models parameter for model cache control
- `ldm.generate.Generate()` now takes an argument named `max_loaded_models`. This is an integer that limits the model cache size; when the cache reaches the limit, the least recently used models are purged from it.
- The CLI takes an argument `--max_loaded_models`, defaulting to 2. This keeps one model in the GPU and one in CPU RAM, so switching back and forth between the two is fast.
- To not cache models at all, pass `--max_loaded_models=1`.
parent 3cdfedc649
commit dc556cb1a7
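For readers skimming the diff below: the change replaces a free-RAM heuristic with a simple count-bounded LRU cache of loaded models. The following is a minimal sketch of that policy only; `SimpleModelCache` and `loader` are made-up names for illustration, and the real `ModelCache` additionally shuttles models between GPU and CPU.

```python
# Minimal sketch of a count-bounded LRU model cache (illustration only;
# SimpleModelCache is a hypothetical name, not the InvokeAI class).
class SimpleModelCache:
    def __init__(self, max_loaded_models: int = 2):
        self.max_loaded_models = max_loaded_models
        self.models = {}   # model_name -> loaded model object
        self.stack = []    # LRU order: index 0 is the oldest entry

    def get_model(self, name, loader):
        if name in self.models:
            self.stack.remove(name)            # cache hit: refresh recency
        else:
            if len(self.models) >= self.max_loaded_models and self.stack:
                oldest = self.stack.pop(0)     # purge the least recently used model
                del self.models[oldest]
            self.models[name] = loader(name)   # loader is any callable that loads weights
        self.stack.append(name)
        return self.models[name]

cache = SimpleModelCache(max_loaded_models=2)
cache.get_model('stable-diffusion-1.4', lambda n: f'<weights of {n}>')
cache.get_model('stable-diffusion-1.5', lambda n: f'<weights of {n}>')
cache.get_model('waifu-diffusion',      lambda n: f'<weights of {n}>')  # evicts 1.4
assert 'stable-diffusion-1.4' not in cache.models
```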
@@ -57,7 +57,7 @@ torch.bernoulli = fix_func(torch.bernoulli)
 torch.multinomial = fix_func(torch.multinomial)
 
 # this is fallback model in case no default is defined
-FALLBACK_MODEL_NAME='stable-diffusion-1.4'
+FALLBACK_MODEL_NAME='stable-diffusion-1.5'
 
 """Simplified text to image API for stable diffusion/latent diffusion
 
@@ -146,6 +146,7 @@ class Generate:
             esrgan=None,
             free_gpu_mem=False,
             safety_checker:bool=False,
+            max_loaded_models:int=3,
             # these are deprecated; if present they override values in the conf file
             weights = None,
             config = None,
@@ -177,6 +178,7 @@ class Generate:
         self.codeformer = codeformer
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
+        self.max_loaded_models = max_loaded_models
         self.size_matters = True # used to warn once about large image sizes and VRAM
         self.txt2mask = None
         self.safety_checker = None
@@ -198,7 +200,7 @@ class Generate:
         self.precision = choose_precision(self.device)
 
         # model caching system for fast switching
-        self.model_cache = ModelCache(mconfig,self.device,self.precision)
+        self.model_cache = ModelCache(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
         self.model_name = model or self.model_cache.default_model() or FALLBACK_MODEL_NAME
 
         # for VRAM usage statistics
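A short usage sketch of the new keyword as it lands in `Generate` above. The model name is an assumption (whatever is defined in your `configs/models.yaml`); everything else is left at its defaults.

```python
from ldm.generate import Generate

# Keep at most two models resident: the active one on the GPU plus one
# parked in CPU RAM, so switching between them avoids a full reload.
gr = Generate(
    model='stable-diffusion-1.5',   # assumed to be a stanza in configs/models.yaml
    max_loaded_models=2,
)
```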
@@ -413,6 +413,13 @@ class Args(object):
             action='store_true',
             help='Deprecated way to set --precision=float32',
         )
+        model_group.add_argument(
+            '--max_loaded_models',
+            dest='max_loaded_models',
+            type=int,
+            default=2,
+            help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
+        )
         model_group.add_argument(
             '--free_gpu_mem',
             dest='free_gpu_mem',
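For reference, here is the same option on a bare `argparse.ArgumentParser`; in the repository it is registered on the `model_group` of the `Args` parser as shown above, so this standalone parser is only an approximation.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    '--max_loaded_models',
    dest='max_loaded_models',
    type=int,
    default=2,
    help='Maximum number of models to keep in memory for fast switching, including the one in GPU',
)

opt = parser.parse_args(['--max_loaded_models', '1'])
print(opt.max_loaded_models)   # 1 -> models are never cached in CPU RAM
```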
@@ -20,12 +20,10 @@ from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
 from ldm.util import instantiate_from_config
 
-GIGS=2**30
-AVG_MODEL_SIZE=2.1*GIGS
-DEFAULT_MIN_AVAIL=2*GIGS
+DEFAULT_MAX_MODELS=3
 
 class ModelCache(object):
-    def __init__(self, config:OmegaConf, device_type:str, precision:str, min_avail_mem=DEFAULT_MIN_AVAIL):
+    def __init__(self, config:OmegaConf, device_type:str, precision:str, max_loaded_models=DEFAULT_MAX_MODELS):
         '''
         Initialize with the path to the models.yaml config file,
         the torch device type, and precision. The optional
@@ -38,7 +36,7 @@ class ModelCache(object):
         self.config = config
         self.precision = precision
         self.device = torch.device(device_type)
-        self.min_avail_mem = min_avail_mem
+        self.max_loaded_models = max_loaded_models
         self.models = {}
         self.stack = [] # this is an LRU FIFO
         self.current_model = None
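Constructing the cache directly now takes the count-based cap instead of the old `min_avail_mem` floor. A hedged sketch follows, assuming the module path `ldm.invoke.model_cache` and typical device/precision values:

```python
import torch
from omegaconf import OmegaConf
from ldm.invoke.model_cache import ModelCache   # import path assumed for this era of the codebase

config = OmegaConf.load('configs/models.yaml')
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'

# max_loaded_models replaces the removed min_avail_mem / AVG_MODEL_SIZE heuristic.
cache = ModelCache(config, device_type, 'float16', max_loaded_models=3)
```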
@@ -63,8 +61,8 @@ class ModelCache(object):
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
             hash = self.models[model_name]['hash']
-        else:
-            self._check_memory()
+        else: # we're about to load a new model, so potentially unload the least recently used one
+            self._check_cache_size()
             try:
                 requested_model, width, height, hash = self._load_model(model_name)
                 self.models[model_name] = {}
@@ -178,14 +176,16 @@ class ModelCache(object):
             self._invalidate_cached_model(model_name)
         return True
 
-    def _check_memory(self):
-        avail_memory = psutil.virtual_memory()[1]
-        if AVG_MODEL_SIZE + self.min_avail_mem > avail_memory:
+    def _check_cache_size(self):
+        num_loaded_models = len(self.models)
+        if num_loaded_models >= self.max_loaded_models:
             least_recent_model = self._pop_oldest_model()
+            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
             if least_recent_model is not None:
                 del self.models[least_recent_model]
                 gc.collect()
+        else:
+            print(f'>> Model will be cached in CPU')
 
     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
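To make the behavioral change concrete: the old `_check_memory` purged based on estimated free system RAM, while the new `_check_cache_size` purges based purely on how many models are resident. A side-by-side sketch of the two conditions, reusing the constants that this commit removes:

```python
import psutil

GIGS = 2**30
AVG_MODEL_SIZE = 2.1 * GIGS     # rough per-model footprint used by the old heuristic
DEFAULT_MIN_AVAIL = 2 * GIGS    # free-RAM floor used by the old heuristic

models = {'stable-diffusion-1.5': object()}   # stand-in for the cache's model dict
max_loaded_models = 2

# Old rule: purge if loading another ~2.1 GB model would drop free RAM below the floor.
avail_memory = psutil.virtual_memory()[1]     # field 1 of the namedtuple is 'available'
old_should_purge = AVG_MODEL_SIZE + DEFAULT_MIN_AVAIL > avail_memory

# New rule: purge as soon as the resident-model count reaches the cap.
new_should_purge = len(models) >= max_loaded_models
```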
@@ -261,7 +261,7 @@ class ModelCache(object):
     def unload_model(self, model_name:str):
         if model_name not in self.models:
             return
-        print(f'>> Caching model {model_name} in system RAM')
+        print(f'>> Unloading {model_name} from GPU')
         model = self.models[model_name]['model']
         self.models[model_name]['model'] = self._model_to_cpu(model)
         gc.collect()
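`unload_model` keeps the weights around but parks them in system RAM via `_model_to_cpu`. A generic PyTorch sketch of that pattern follows; the function name and the explicit cache flush are illustrative, not the exact InvokeAI helper.

```python
import gc
import torch

def move_model_to_cpu(model: torch.nn.Module) -> torch.nn.Module:
    """Park a model's weights in system RAM so the GPU is free for the next model."""
    model = model.to('cpu')
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # hand the freed VRAM back to the CUDA allocator pool
    return model
```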
@@ -322,8 +322,7 @@ class ModelCache(object):
         to be the least recently accessed model. Do not
         pop the last one, because it is in active use!
         '''
-        if len(self.stack) > 1:
-            return self.stack.pop(0)
+        return self.stack.pop(0)
 
     def _push_newest_model(self,model_name:str):
         '''
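Recency is tracked with a plain Python list used as a FIFO: model names are appended on use and evicted from index 0. A tiny sketch of that pairing (the de-duplication in `push_newest` mirrors what `_push_newest_model` is expected to do, which is an assumption here):

```python
stack = []   # oldest model name first, most recently used last

def push_newest(name: str) -> None:
    if name in stack:
        stack.remove(name)   # re-using a model moves it to the newest slot
    stack.append(name)

def pop_oldest() -> str:
    return stack.pop(0)      # least recently used model name

push_newest('stable-diffusion-1.4')
push_newest('stable-diffusion-1.5')
push_newest('stable-diffusion-1.4')          # refresh 1.4
assert pop_oldest() == 'stable-diffusion-1.5'
```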
@@ -38,6 +38,10 @@ def main():
     if args.weights:
         print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
         sys.exit(-1)
+    if args.max_loaded_models is not None:
+        if args.max_loaded_models <= 0:
+            print('--max_loaded_models must be >= 1; using 1')
+            args.max_loaded_models = 1
 
     print('* Initializing, be patient...')
     from ldm.generate import Generate
@@ -81,6 +85,7 @@ def main():
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
             safety_checker=opt.safety_checker,
+            max_loaded_models=opt.max_loaded_models,
         )
     except (FileNotFoundError, IOError, KeyError) as e:
         print(f'{e}. Aborting.')
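Putting the two `main()` hunks together, the flow is: parse the flag, clamp nonsense values, and hand the result to `Generate`. A condensed sketch under the same assumptions as above (the real parser and the full set of `Generate` arguments are omitted):

```python
import argparse
from ldm.generate import Generate

parser = argparse.ArgumentParser()
parser.add_argument('--max_loaded_models', type=int, default=2)
opt = parser.parse_args()

# Clamp rather than abort: a value of 1 simply disables CPU-side caching.
if opt.max_loaded_models is not None and opt.max_loaded_models <= 0:
    print('--max_loaded_models must be >= 1; using 1')
    opt.max_loaded_models = 1

gr = Generate(
    max_loaded_models=opt.max_loaded_models,
    # other arguments (model, sampler_name, safety_checker, ...) left at their defaults
)
```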