fix model_cache memory management issues

This commit is contained in:
Lincoln Stein 2022-11-01 17:22:48 -04:00
parent 1379642fc6
commit 942a202945
2 changed files with 43 additions and 20 deletions

View File

@ -14,6 +14,7 @@ import sys
import traceback
import transformers
import io
import gc
import hashlib
import cv2
import skimage
@ -789,9 +790,20 @@ class Generate:
if self.model_name == model_name and self.model is not None:
return self.model
model_data = self.model_cache.get_model(model_name)
if model_data is None or len(model_data) == 0:
return None
# the model cache does the loading and offloading
cache = self.model_cache
cache.print_vram_usage()
# have to get rid of all references to model in order
# to free it from GPU memory
self.model = None
self.sampler = None
self.generators = {}
gc.collect()
model_data = cache.get_model(model_name)
if model_data is None: # restore previous
model_data = cache.get_model(self.model_name)
self.model = model_data['model']
self.width = model_data['width']

View File

@ -52,7 +52,9 @@ class ModelCache(object):
return None
if self.current_model != model_name:
self.unload_model(self.current_model)
if model_name not in self.models: # make room for a new one
self._make_cache_room()
self.offload_model(self.current_model)
if model_name in self.models:
requested_model = self.models[model_name]['model']
@ -61,8 +63,7 @@ class ModelCache(object):
width = self.models[model_name]['width']
height = self.models[model_name]['height']
hash = self.models[model_name]['hash']
else: # we're about to load a new model, so potentially unload the least recently used one
self._check_cache_size()
else: # we're about to load a new model, so potentially offload the least recently used one
try:
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {}
@ -176,17 +177,6 @@ class ModelCache(object):
self._invalidate_cached_model(model_name)
return True
def _check_cache_size(self):
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
if least_recent_model is not None:
del self.models[least_recent_model]
gc.collect()
else:
print(f'>> Model will be cached in CPU')
def _load_model(self, model_name:str):
"""Load and initialize the model from configuration variables passed at object creation time"""
if model_name not in self.config:
@ -258,16 +248,37 @@ class ModelCache(object):
)
return model, width, height, model_hash
def unload_model(self, model_name:str):
def offload_model(self, model_name:str):
'''
Offload the indicated model to CPU. Will call
_make_cache_room() to free space if needed.
'''
if model_name not in self.models:
return
print(f'>> Unloading {model_name} from GPU')
message = f'>> Offloading {model_name} to CPU'
print(message)
model = self.models[model_name]['model']
self.models[model_name]['model'] = self._model_to_cpu(model)
gc.collect()
if self._has_cuda():
torch.cuda.empty_cache()
def _make_cache_room(self):
num_loaded_models = len(self.models)
if num_loaded_models >= self.max_loaded_models:
least_recent_model = self._pop_oldest_model()
print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
if least_recent_model is not None:
del self.models[least_recent_model]
gc.collect()
def print_vram_usage(self):
if self._has_cuda:
print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
def commit(self,config_file_path:str):
'''
Write current configuration out to the indicated file.
@ -293,7 +304,7 @@ class ModelCache(object):
'''
def _invalidate_cached_model(self,model_name:str):
self.unload_model(model_name)
self.offload_model(model_name)
if model_name in self.stack:
self.stack.remove(model_name)
self.models.pop(model_name,None)