From 25332093263ffbab1533236c1eb03ed6fd50b995 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 23 May 2023 03:48:22 +0300 Subject: [PATCH] Rewrite cache to weak references --- .../app/services/model_manager_service.py | 6 +- .../backend/model_management/model_cache.py | 156 ++++++++++-------- invokeai/backend/restoration/codeformer.py | 2 +- invokeai/backend/restoration/gfpgan.py | 2 +- 4 files changed, 95 insertions(+), 71 deletions(-) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index ee466849a0..36696affce 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -267,10 +267,10 @@ class ModelManagerService(ModelManagerServiceBase): logger.debug(f'config file={config_file}') device = torch.device(choose_torch_device()) - if config.precision == "auto": + precision = config.precision + if precision == "auto": precision = choose_precision(device) - dtype = torch.float32 if precision=='float32' \ - else torch.float16 + dtype = torch.float32 if precision == 'float32' else torch.float16 # this is transitional backward compatibility # support for the deprecated `max_loaded_models` diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py index 4c6616ef20..714efb2b28 100644 --- a/invokeai/backend/model_management/model_cache.py +++ b/invokeai/backend/model_management/model_cache.py @@ -17,6 +17,7 @@ context. Use like this: """ import contextlib +import weakref import gc import os import sys @@ -427,15 +428,15 @@ class ModelCache(object): pass class _CacheRecord: - model: Any + key: str size: int + cache: ModelCache _locks: int - _cache: ModelCache - def __init__(self, cache, model: Any, size: int): - self._cache = cache - self.model = model + def __init__(self, cache, key: Any, size: int): + self.key = key self.size = size + self.cache = cache self._locks = 0 def lock(self): @@ -451,8 +452,9 @@ class _CacheRecord: @property def loaded(self): - if self.model is not None and hasattr(self.model, "device"): - return self.model.device != self._cache.storage_device + model = self.cache._cached_models.get(self.key, None) + if model is not None and hasattr(model, "device"): + return model.device != self.cache.storage_device else: return False @@ -478,22 +480,23 @@ class ModelCache(object): :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially :param sha_chunksize: Chunksize to use when calculating sha256 model hash ''' - max_cache_size = 9999 + #max_cache_size = 9999 execution_device = torch.device('cuda') - self.models: Dict[str, _CacheRecord] = dict() self.model_infos: Dict[str, ModelInfoBase] = dict() - self.stack: Sequence = list() self.lazy_offloading = lazy_offloading - self.sequential_offload: bool=sequential_offload + #self.sequential_offload: bool=sequential_offload self.precision: torch.dtype=precision - self.current_cache_size: int=0 self.max_cache_size: int=max_cache_size self.execution_device: torch.device=execution_device self.storage_device: torch.device=storage_device self.sha_chunksize=sha_chunksize self.logger = logger + self._cached_models = weakref.WeakValueDictionary() + self._cached_infos = weakref.WeakKeyDictionary() + self._cache_stack = list() + def get_key( self, model_path: str, @@ -546,8 +549,9 @@ class ModelCache(object): self, repo_id_or_path: Union[str, Path], model_type: SDModelType = SDModelType.Diffusers, - submodel: SDModelType = None, 
- revision: str = None, + submodel: Optional[SDModelType] = None, + revision: Optional[str] = None, + variant: Optional[str] = None, gpu_load: bool = True, ) -> Any: @@ -565,7 +569,9 @@ class ModelCache(object): submodel_type=submodel, ) - if key not in self.models: + # TODO: lock for no copies on simultaneous calls? + model = self._cached_models.get(key, None) + if model is None: self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}') # this will remove older cached models until @@ -574,56 +580,54 @@ class ModelCache(object): # clean memory to make MemoryUsage() more accurate gc.collect() - model_obj = model_info.get_model(submodel, torch_dtype=self.precision) + model = model_info.get_model(submodel, torch_dtype=self.precision) if mem_used := model_info.get_size(submodel): - logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB') - self.current_cache_size += mem_used # increment size of the cache + self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB') - self.models[key] = _CacheRecord(self, model_obj, mem_used) + self._cached_models[key] = model + self._cached_infos[model] = _CacheRecord(self, key, mem_used) with suppress(Exception): - self.stack.remove(key) - self.stack.append(key) + self._cache_stack.remove(model) + self._cache_stack.append(model) - return self.ModelLocker(self, key, self.models[key].model, gpu_load) - - def uncache_model(self, key: str): - '''Remove corresponding model from the cache''' - self.models.pop(key, None) - with contextlib.suppress(ValueError): - self.stack.remove(key) + return self.ModelLocker(self, key, model, gpu_load) class ModelLocker(object): def __init__(self, cache, key, model, gpu_load): self.gpu_load = gpu_load self.cache = cache self.key = key - # This will keep a copy of the model in RAM until the locker - # is garbage collected. Needs testing! self.model = model def __enter__(self) -> Any: if not hasattr(self.model, 'to'): return self.model - cache_entry = self.cache.models[self.key] + cache_entry = self.cache._cached_infos[self.model] # NOTE that the model has to have the to() method in order for this # code to move it into GPU! 
if self.gpu_load: cache_entry.lock() - - if self.cache.lazy_offloading: - self.cache._offload_unlocked_models() - - if self.model.device != self.cache.execution_device: - self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}') - with VRAMUsage() as mem: - self.model.to(self.cache.execution_device) # move into GPU - self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB') - - self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}') - self.cache._print_cuda_stats() + + try: + if self.cache.lazy_offloading: + self.cache._offload_unlocked_models() + + if self.model.device != self.cache.execution_device: + self.cache.logger.debug(f'Moving {self.key} into {self.cache.execution_device}') + with VRAMUsage() as mem: + self.model.to(self.cache.execution_device) # move into GPU + self.cache.logger.debug(f'GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB') + + self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}') + self.cache._print_cuda_stats() + + except: + cache_entry.unlock() + raise + # TODO: not fully understand # in the event that the caller wants the model in RAM, we @@ -637,11 +641,13 @@ class ModelCache(object): if not hasattr(self.model, 'to'): return - self.cache.models[self.key].unlock() + cache_entry = self.cache._cached_infos[self.model] + cache_entry.unlock() if not self.cache.lazy_offloading: self.cache._offload_unlocked_models() self.cache._print_cuda_stats() + def model_hash( self, repo_id_or_path: Union[str, Path], @@ -661,55 +667,73 @@ class ModelCache(object): def cache_size(self) -> float: "Return the current size of the cache, in GB" - return self.current_cache_size / GIG + current_cache_size = sum([m.size for m in self._cached_infos.values()]) + return current_cache_size / GIG def _has_cuda(self) -> bool: return self.execution_device.type == 'cuda' def _print_cuda_stats(self): vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) - ram = "%4.2fG" % (self.current_cache_size / GIG) + ram = "%4.2fG" % self.cache_size() + cached_models = 0 loaded_models = 0 locked_models = 0 - for cache_entry in self.models.values(): - if cache_entry.loaded: + for model_info in self._cached_infos.values(): + cached_models += 1 + if model_info.loaded: loaded_models += 1 - if cache_entry.locked: + if model_info.locked: locked_models += 1 - logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; locked_models/loaded_models = {locked_models}/{loaded_models}") + self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}") + def _make_cache_room(self, model_size): # calculate how much memory this model will require #multiplier = 2 if self.precision==torch.float32 else 1 bytes_needed = model_size maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes - current_size = self.current_cache_size + current_size = sum([m.size for m in self._cached_infos.values()]) if current_size + bytes_needed > maximum_size: - logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB') + self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB') + + self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}") pos = 0 - while current_size + bytes_needed > maximum_size and current_size > 0 and len(self.stack) > 0 and pos < 
len(self.stack): - model_key = self.stack[pos] - cache_entry = self.models[model_key] - if not cache_entry.locked: - logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)') - self.uncache_model(model_key) # del self.stack[pos] - current_size -= cache_entry.size + while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): + model = self._cache_stack[pos] + model_info = self._cached_infos[model] + + refs = sys.getrefcount(model) + + device = model.device if hasattr(model, "device") else None + self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}") + + # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function + if not model_info.locked and refs <= 3: + self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)') + current_size -= model_info.size + del self._cache_stack[pos] + del model + del model_info + else: pos += 1 - self.current_cache_size = current_size gc.collect() + torch.cuda.empty_cache() + + self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}") + def _offload_unlocked_models(self): - for key in self.models.keys(): - cache_entry = self.models[key] - if not cache_entry.locked and cache_entry.loaded: - self.logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}') - cache_entry.model.to(self.storage_device) + for model, model_info in self._cached_infos.items(): + if not model_info.locked and model_info.loaded: + self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}') + model.to(self.storage_device) def _local_model_hash(self, model_path: Union[str, Path]) -> str: sha = hashlib.sha256() @@ -721,7 +745,7 @@ class ModelCache(object): hash = f.read() return hash - logger.debug(f'computing hash of model {path.name}') + self.logger.debug(f'computing hash of model {path.name}') for file in list(path.rglob("*.ckpt")) \ + list(path.rglob("*.safetensors")) \ + list(path.rglob("*.pth")): diff --git a/invokeai/backend/restoration/codeformer.py b/invokeai/backend/restoration/codeformer.py index b7073f8f8b..4844b6e4cb 100644 --- a/invokeai/backend/restoration/codeformer.py +++ b/invokeai/backend/restoration/codeformer.py @@ -24,7 +24,7 @@ class CodeFormerRestoration: self.codeformer_model_exists = self.model_path.exists() if not self.codeformer_model_exists: - logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path) + logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}") sys.path.append(os.path.abspath(codeformer_dir)) def process(self, image, strength, device, seed=None, fidelity=0.75): diff --git a/invokeai/backend/restoration/gfpgan.py b/invokeai/backend/restoration/gfpgan.py index 063feaa89a..1da60c2f51 100644 --- a/invokeai/backend/restoration/gfpgan.py +++ b/invokeai/backend/restoration/gfpgan.py @@ -18,7 +18,7 @@ class GFPGAN: self.gfpgan_model_exists = os.path.isfile(self.model_path) if not self.gfpgan_model_exists: - logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path) + logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}") return None def model_exists(self):
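
Note appended for readers of this patch: the eviction logic introduced above condenses to a small self-contained sketch. The cache keeps only weak references to models (a WeakValueDictionary keyed by the cache key for lookups, and a WeakKeyDictionary from model to its bookkeeping record), while an ordinary list acts as the LRU stack holding the sole strong reference. A model is dropped only when sys.getrefcount() shows no holders outside the cache (<= 3: the stack entry, the local variable, and getrefcount's own temporary argument), mirroring the "refs <= 3" check in _make_cache_room(). All names below (MiniModelCache, FakeModel, put, _make_room) are illustrative only and are not InvokeAI APIs; locking, device placement, and logging are deliberately left out.

    # Minimal sketch of the weak-reference cache pattern introduced by this patch.
    # Names are illustrative, not InvokeAI APIs; assumes CPython reference counting.
    import gc
    import sys
    import weakref


    class MiniModelCache:
        def __init__(self, max_size: int):
            self.max_size = max_size                       # budget in arbitrary size units
            self._models = weakref.WeakValueDictionary()   # key -> model; entry vanishes with the model
            self._sizes = weakref.WeakKeyDictionary()      # model -> size; entry vanishes with the model
            self._stack = []                               # the only strong refs, oldest first (LRU order)

        def put(self, key: str, model, size: int) -> None:
            self._make_room(size)
            self._models[key] = model
            self._sizes[model] = size
            if model in self._stack:
                self._stack.remove(model)
            self._stack.append(model)                      # most recently used goes to the end

        def get(self, key: str):
            return self._models.get(key, None)

        def current_size(self) -> int:
            return sum(self._sizes.values())

        def _make_room(self, size_needed: int) -> None:
            current = self.current_size()
            pos = 0
            while current + size_needed > self.max_size and pos < len(self._stack):
                model = self._stack[pos]
                # Expected refcount when nobody outside the cache holds the model:
                # 1 for the stack entry, 1 for the local name `model`,
                # 1 for the temporary argument inside getrefcount() itself.
                if sys.getrefcount(model) <= 3:
                    current -= self._sizes.get(model, 0)
                    del self._stack[pos]                   # drop the cache's strong reference
                    del model                              # drop the local strong reference as well
                else:
                    pos += 1                               # still referenced elsewhere; skip it
            gc.collect()                                   # clean up anything that just lost its last reference


    if __name__ == "__main__":
        class FakeModel:                                   # stand-in for a pipeline / torch module
            pass

        cache = MiniModelCache(max_size=2)
        a, b = FakeModel(), FakeModel()
        cache.put("model-a", a, size=1)
        cache.put("model-b", b, size=1)
        del a                                              # no references left outside the cache
        cache.put("model-c", FakeModel(), size=1)          # over budget -> "model-a" is evicted
        print(cache.get("model-a"))                        # None: weak entries vanished with the object
        print(cache.get("model-b") is b)                   # True: still referenced, so it was skipped

Because the dictionaries hold only weak references, dropping the stack entry is all that eviction has to do: once the last outside holder (in the patch, typically a ModelLocker) releases the model, the key and record entries disappear on their own, and the size accounting can be recomputed from the surviving records, which is why the patch removes the separately maintained current_cache_size counter in favor of summing _cached_infos in cache_size() and _make_cache_room().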