diff --git a/ldm/generate.py b/ldm/generate.py
index 2aec576606..223af95f0e 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -14,6 +14,7 @@ import sys
 import traceback
 import transformers
 import io
+import gc
 import hashlib
 import cv2
 import skimage
@@ -789,9 +790,22 @@ class Generate:
         if self.model_name == model_name and self.model is not None:
             return self.model
 
-        model_data = self.model_cache.get_model(model_name)
-        if model_data is None or len(model_data) == 0:
-            return None
+        # the model cache does the loading and offloading
+        cache = self.model_cache
+        cache.print_vram_usage()
+
+        # have to get rid of all references to model in order
+        # to free it from GPU memory
+        self.model = None
+        self.sampler = None
+        self.generators = {}
+        gc.collect()
+
+        model_data = cache.get_model(model_name)
+        if model_data is None: # restore previous
+            model_data = cache.get_model(self.model_name)
+        if model_data is None: # nothing could be loaded; don't subscript None below
+            return None
 
         self.model = model_data['model']
         self.width = model_data['width']
diff --git a/ldm/invoke/model_cache.py b/ldm/invoke/model_cache.py
index 3380d453c3..7b434941df 100644
--- a/ldm/invoke/model_cache.py
+++ b/ldm/invoke/model_cache.py
@@ -52,7 +52,9 @@ class ModelCache(object):
             return None
 
         if self.current_model != model_name:
-            self.unload_model(self.current_model)
+            if model_name not in self.models: # make room for a new one
+                self._make_cache_room()
+            self.offload_model(self.current_model)
 
         if model_name in self.models:
             requested_model = self.models[model_name]['model']
@@ -61,8 +63,7 @@ class ModelCache(object):
             width = self.models[model_name]['width']
             height = self.models[model_name]['height']
             hash = self.models[model_name]['hash']
-        else: # we're about to load a new model, so potentially unload the least recently used one
-            self._check_cache_size()
+        else: # we're about to load a new model, so potentially offload the least recently used one
             try:
                 requested_model, width, height, hash = self._load_model(model_name)
                 self.models[model_name] = {}
@@ -176,17 +177,6 @@ class ModelCache(object):
         self._invalidate_cached_model(model_name)
         return True
 
-    def _check_cache_size(self):
-        num_loaded_models = len(self.models)
-        if num_loaded_models >= self.max_loaded_models:
-            least_recent_model = self._pop_oldest_model()
-            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
-            if least_recent_model is not None:
-                del self.models[least_recent_model]
-                gc.collect()
-        else:
-            print(f'>> Model will be cached in CPU')
-
     def _load_model(self, model_name:str):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if model_name not in self.config:
@@ -258,16 +248,39 @@ class ModelCache(object):
             )
         return model, width, height, model_hash
 
-    def unload_model(self, model_name:str):
+    def offload_model(self, model_name:str):
+        '''
+        Offload the indicated model to CPU. Does nothing if
+        the model is not currently in the cache.
+        '''
+
         if model_name not in self.models:
             return
-        print(f'>> Unloading {model_name} from GPU')
+
+        message = f'>> Offloading {model_name} to CPU'
+        print(message)
         model = self.models[model_name]['model']
         self.models[model_name]['model'] = self._model_to_cpu(model)
+
         gc.collect()
         if self._has_cuda():
             torch.cuda.empty_cache()
 
+    def _make_cache_room(self):
+        '''Purge the model returned by _pop_oldest_model() once max_loaded_models is reached.'''
+        num_loaded_models = len(self.models)
+        if num_loaded_models >= self.max_loaded_models:
+            least_recent_model = self._pop_oldest_model()
+            print(f'>> Cache limit (max={self.max_loaded_models}) reached. Purging {least_recent_model}')
+            if least_recent_model is not None:
+                del self.models[least_recent_model]
+                gc.collect()
+
+    def print_vram_usage(self):
+        '''Print the VRAM currently allocated through torch when CUDA is available.'''
+        if self._has_cuda():
+            print ('>> Current VRAM usage: ','%4.2fG' % (torch.cuda.memory_allocated() / 1e9))
+
     def commit(self,config_file_path:str):
         '''
         Write current configuration out to the indicated file.
@@ -293,7 +306,7 @@ class ModelCache(object):
         '''
 
     def _invalidate_cached_model(self,model_name:str):
-        self.unload_model(model_name)
+        self.offload_model(model_name)
         if model_name in self.stack:
             self.stack.remove(model_name)
         self.models.pop(model_name,None)