Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fix model_cache memory management issues
commit ab2b5a691d
parent c7de2b2801
@@ -14,6 +14,7 @@ import sys
 import traceback
 import transformers
 import io
+import gc
 import hashlib
 import cv2
 import skimage
@@ -789,9 +790,20 @@ class Generate:
         if self.model_name == model_name and self.model is not None:
             return self.model
 
-        model_data = self.model_cache.get_model(model_name)
-        if model_data is None or len(model_data) == 0:
-            return None
+        # the model cache does the loading and offloading
+        cache = self.model_cache
+        cache.print_vram_usage()
+
+        # have to get rid of all references to model in order
+        # to free it from GPU memory
+        self.model = None
+        self.sampler = None
+        self.generators = {}
+        gc.collect()
+
+        model_data = cache.get_model(model_name)
+        if model_data is None: # restore previous
+            model_data = cache.get_model(self.model_name)
 
         self.model = model_data['model']
         self.width = model_data['width']
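The hunk above changes Generate so that every reference to the active model (self.model, self.sampler, self.generators) is dropped and garbage-collected before the cache is asked for the next one; otherwise the old weights stay reachable from Python and their VRAM is never returned. Below is a minimal, standalone sketch of that pattern, assuming only that torch is installed; TinyHolder, swap_model and build_new_model are hypothetical names for illustration, not part of InvokeAI.

import gc
import torch


class TinyHolder:
    '''Hypothetical stand-in for Generate: holds a model plus objects derived from it.'''

    def __init__(self):
        self.model = None
        self.sampler = None      # anything that keeps a reference back to self.model
        self.generators = {}

    def swap_model(self, build_new_model):
        # 1. Drop every reference so the old weights become unreachable ...
        self.model = None
        self.sampler = None
        self.generators = {}
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()   # ... and hand cached CUDA blocks back to the driver

        # 2. Only now load (or fetch) the replacement.
        self.model = build_new_model()
        return self.model


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    holder = TinyHolder()
    holder.swap_model(lambda: torch.nn.Linear(256, 256).to(device))
    holder.swap_model(lambda: torch.nn.Linear(512, 512).to(device))  # the first Linear is freed before this loads

The order matters: calling gc.collect() while self.sampler still points at the old model would collect nothing, which is why the diff clears all three attributes first.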
@@ -52,7 +52,9 @@ class ModelCache(object):
             return None
 
         if self.current_model != model_name:
-            self.unload_model(self.current_model)
+            if model_name not in self.models: # make room for a new one
+                self._make_cache_room()
+            self.offload_model(self.current_model)
 
         if model_name in self.models:
             requested_model = self.models[model_name]['model']
@@ -61,8 +63,7 @@ class ModelCache(object):
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
             hash = self.models[model_name]['hash']
-        else: # we're about to load a new model, so potentially unload the least recently used one
-            self._check_cache_size()
+        else: # we're about to load a new model, so potentially offload the least recently used one
             try:
                 requested_model, width, height, hash = self._load_model(model_name)
                 self.models[model_name] = {}
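The two hunks above rework ModelCache.get_model() so that room is made in the cache before a brand-new model is loaded, while the model being switched away from is offloaded rather than discarded. The following is a hedged, self-contained sketch of that caching discipline only; SketchCache, load_fn and max_loaded are illustrative names and not InvokeAI's API.

class SketchCache:
    def __init__(self, load_fn, max_loaded=2):
        self.load_fn = load_fn          # callable: name -> model object
        self.max_loaded = max_loaded
        self.models = {}                # name -> model kept resident
        self.stack = []                 # access order, most recent last
        self.current = None

    def get_model(self, name):
        if self.current is not None and self.current != name:
            if name not in self.models:     # make room only for a brand-new entry
                self._make_room()
            # the outgoing model stays cached; a real cache would move it to CPU here
        if name not in self.models:
            self.models[name] = self.load_fn(name)
        self._touch(name)
        self.current = name
        return self.models[name]

    def _make_room(self):
        if len(self.models) >= self.max_loaded and self.stack:
            oldest = self.stack.pop(0)
            self.models.pop(oldest, None)   # purge the least recently used entry

    def _touch(self, name):
        if name in self.stack:
            self.stack.remove(name)
        self.stack.append(name)


if __name__ == '__main__':
    cache = SketchCache(load_fn=lambda n: f'<model {n}>', max_loaded=2)
    cache.get_model('a')
    cache.get_model('b')
    cache.get_model('c')
    print(sorted(cache.models))   # 'a' was purged: ['b', 'c']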
@@ -176,17 +177,6 @@ class ModelCache(object):
         self._invalidate_cached_model(model_name)
         return True
 
-    def _check_cache_size(self):
-        num_loaded_models = len(self.models)
-        if num_loaded_models >= self.max_loaded_models:
-            least_recent_model = self._pop_oldest_model()
-            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
-            if least_recent_model is not None:
-                del self.models[least_recent_model]
-                gc.collect()
-        else:
-            print(f'>> Model will be cached in CPU')
-
     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if model_name not in self.config:
@@ -258,16 +248,37 @@ class ModelCache(object):
         )
         return model, width, height, model_hash
 
-    def unload_model(self, model_name:str):
+    def offload_model(self, model_name:str):
+        '''
+        Offload the indicated model to CPU. Will call
+        _make_cache_room() to free space if needed.
+        '''
+
         if model_name not in self.models:
             return
-        print(f'>> Unloading {model_name} from GPU')
+
+        message = f'>> Offloading {model_name} to CPU'
+        print(message)
         model = self.models[model_name]['model']
         self.models[model_name]['model'] = self._model_to_cpu(model)
+
         gc.collect()
         if self._has_cuda():
             torch.cuda.empty_cache()
+
+    def _make_cache_room(self):
+        num_loaded_models = len(self.models)
+        if num_loaded_models >= self.max_loaded_models:
+            least_recent_model = self._pop_oldest_model()
+            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
+            if least_recent_model is not None:
+                del self.models[least_recent_model]
+                gc.collect()
+
+    def print_vram_usage(self):
+        if self._has_cuda():
+            print('>> Current VRAM usage: ', '%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
+
     def commit(self,config_file_path:str):
         '''
         Write current configuration out to the indicated file.
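The new offload_model() and print_vram_usage() above rest on a general PyTorch pattern: moving a module to CPU keeps its weights in system RAM so it can be restored later, and only gc.collect() plus torch.cuda.empty_cache() actually return the freed VRAM to the driver. A hedged, standalone sketch of that pattern follows; the function names here are illustrative, not the ModelCache API.

import gc
import torch


def offload_to_cpu(model: torch.nn.Module) -> torch.nn.Module:
    model = model.to('cpu')            # weights now live in system RAM
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()       # hand cached CUDA blocks back to the driver
    return model


def print_vram_usage() -> None:
    if torch.cuda.is_available():
        print('>> Current VRAM usage: %4.2fG' % (torch.cuda.memory_allocated() / 1e9))


if __name__ == '__main__' and torch.cuda.is_available():
    m = torch.nn.Linear(4096, 4096).cuda()
    print_vram_usage()                 # roughly 0.07G for this layer's weights
    m = offload_to_cpu(m)
    print_vram_usage()                 # back near zero once the weights leave the GPU

Restoring the model later is the mirror image (model.to('cuda')), which is why offloading, unlike deletion, keeps switching between cached models cheap.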
@@ -293,7 +304,7 @@ class ModelCache(object):
         '''
 
     def _invalidate_cached_model(self,model_name:str):
-        self.unload_model(model_name)
+        self.offload_model(model_name)
         if model_name in self.stack:
             self.stack.remove(model_name)
         self.models.pop(model_name,None)