mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix model_cache memory management issues
This commit is contained in:
parent
c7de2b2801
commit
ab2b5a691d
@ -14,6 +14,7 @@ import sys
|
||||
import traceback
|
||||
import transformers
|
||||
import io
|
||||
import gc
|
||||
import hashlib
|
||||
import cv2
|
||||
import skimage
|
||||
@ -789,9 +790,20 @@ class Generate:
|
||||
if self.model_name == model_name and self.model is not None:
|
||||
return self.model
|
||||
|
||||
model_data = self.model_cache.get_model(model_name)
|
||||
if model_data is None or len(model_data) == 0:
|
||||
return None
|
||||
# the model cache does the loading and offloading
|
||||
cache = self.model_cache
|
||||
cache.print_vram_usage()
|
||||
|
||||
# have to get rid of all references to model in order
|
||||
# to free it from GPU memory
|
||||
self.model = None
|
||||
self.sampler = None
|
||||
self.generators = {}
|
||||
gc.collect()
|
||||
|
||||
model_data = cache.get_model(model_name)
|
||||
if model_data is None: # restore previous
|
||||
model_data = cache.get_model(self.model_name)
|
||||
|
||||
self.model = model_data['model']
|
||||
self.width = model_data['width']
|
||||
|
@ -52,7 +52,9 @@ class ModelCache(object):
|
||||
return None
|
||||
|
||||
if self.current_model != model_name:
|
||||
self.unload_model(self.current_model)
|
||||
if model_name not in self.models: # make room for a new one
|
||||
self._make_cache_room()
|
||||
self.offload_model(self.current_model)
|
||||
|
||||
if model_name in self.models:
|
||||
requested_model = self.models[model_name]['model']
|
||||
@ -61,8 +63,7 @@ class ModelCache(object):
|
||||
width = self.models[model_name]['width']
|
||||
height = self.models[model_name]['height']
|
||||
hash = self.models[model_name]['hash']
|
||||
else: # we're about to load a new model, so potentially unload the least recently used one
|
||||
self._check_cache_size()
|
||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||
try:
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {}
|
||||
@ -176,17 +177,6 @@ class ModelCache(object):
|
||||
self._invalidate_cached_model(model_name)
|
||||
return True
|
||||
|
||||
def _check_cache_size(self):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
else:
|
||||
print(f'>> Model will be cached in CPU')
|
||||
|
||||
def _load_model(self, model_name:str):
|
||||
"""Load and initialize the model from configuration variables passed at object creation time"""
|
||||
if model_name not in self.config:
|
||||
@ -258,16 +248,37 @@ class ModelCache(object):
|
||||
)
|
||||
return model, width, height, model_hash
|
||||
|
||||
def unload_model(self, model_name:str):
|
||||
def offload_model(self, model_name:str):
|
||||
'''
|
||||
Offload the indicated model to CPU. Will call
|
||||
_make_cache_room() to free space if needed.
|
||||
'''
|
||||
|
||||
if model_name not in self.models:
|
||||
return
|
||||
print(f'>> Unloading {model_name} from GPU')
|
||||
|
||||
message = f'>> Offloading {model_name} to CPU'
|
||||
print(message)
|
||||
model = self.models[model_name]['model']
|
||||
self.models[model_name]['model'] = self._model_to_cpu(model)
|
||||
|
||||
gc.collect()
|
||||
if self._has_cuda():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _make_cache_room(self):
|
||||
num_loaded_models = len(self.models)
|
||||
if num_loaded_models >= self.max_loaded_models:
|
||||
least_recent_model = self._pop_oldest_model()
|
||||
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
|
||||
if least_recent_model is not None:
|
||||
del self.models[least_recent_model]
|
||||
gc.collect()
|
||||
|
||||
def print_vram_usage(self):
|
||||
if self._has_cuda:
|
||||
print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
|
||||
|
||||
def commit(self,config_file_path:str):
|
||||
'''
|
||||
Write current configuration out to the indicated file.
|
||||
@ -293,7 +304,7 @@ class ModelCache(object):
|
||||
'''
|
||||
|
||||
def _invalidate_cached_model(self,model_name:str):
|
||||
self.unload_model(model_name)
|
||||
self.offload_model(model_name)
|
||||
if model_name in self.stack:
|
||||
self.stack.remove(model_name)
|
||||
self.models.pop(model_name,None)
|
||||
|
Loading…
Reference in New Issue
Block a user