diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 714efb2b28..5811223650 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -17,7 +17,6 @@ context. Use like this:
 """
 
 import contextlib
-import weakref
 import gc
 import os
 import sys
@@ -428,14 +427,14 @@ class ModelCache(object):
     pass
 
 
 class _CacheRecord:
-    key: str
     size: int
+    model: Any
     cache: ModelCache
     _locks: int
 
-    def __init__(self, cache, key: Any, size: int):
-        self.key = key
+    def __init__(self, cache, model: Any, size: int):
         self.size = size
+        self.model = model
         self.cache = cache
         self._locks = 0
@@ -452,9 +451,8 @@ class _CacheRecord:
 
     @property
     def loaded(self):
-        model = self.cache._cached_models.get(self.key, None)
-        if model is not None and hasattr(model, "device"):
-            return model.device != self.cache.storage_device
+        if self.model is not None and hasattr(self.model, "device"):
+            return self.model.device != self.cache.storage_device
         else:
             return False
 
@@ -493,8 +491,7 @@ class ModelCache(object):
         self.sha_chunksize=sha_chunksize
         self.logger = logger
 
-        self._cached_models = weakref.WeakValueDictionary()
-        self._cached_infos = weakref.WeakKeyDictionary()
+        self._cached_models = dict()
         self._cache_stack = list()
 
     def get_key(
@@ -570,8 +567,8 @@ class ModelCache(object):
         )
 
         # TODO: lock for no copies on simultaneous calls?
-        model = self._cached_models.get(key, None)
-        if model is None:
+        cache_entry = self._cached_models.get(key, None)
+        if cache_entry is None:
             self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
 
             # this will remove older cached models until
@@ -584,14 +581,14 @@ class ModelCache(object):
             if mem_used := model_info.get_size(submodel):
                 self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
 
-            self._cached_models[key] = model
-            self._cached_infos[model] = _CacheRecord(self, key, mem_used)
+            cache_entry = _CacheRecord(self, model, mem_used)
+            self._cached_models[key] = cache_entry
 
         with suppress(Exception):
-            self._cache_stack.remove(model)
-        self._cache_stack.append(model)
+            self._cache_stack.remove(key)
+        self._cache_stack.append(key)
 
-        return self.ModelLocker(self, key, model, gpu_load)
+        return self.ModelLocker(self, key, cache_entry.model, gpu_load)
 
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
@@ -604,7 +601,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return self.model
 
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
 
             # NOTE that the model has to have the to() method in order for this
            # code to move it into GPU!
@@ -641,7 +638,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return
 
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
             cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
@@ -667,7 +664,7 @@ class ModelCache(object):
 
     def cache_size(self) -> float:
         "Return the current size of the cache, in GB"
-        current_cache_size = sum([m.size for m in self._cached_infos.values()])
+        current_cache_size = sum([m.size for m in self._cached_models.values()])
         return current_cache_size / GIG
 
     def _has_cuda(self) -> bool:
@@ -680,7 +677,7 @@ class ModelCache(object):
         cached_models = 0
         loaded_models = 0
         locked_models = 0
-        for model_info in self._cached_infos.values():
+        for model_info in self._cached_models.values():
             cached_models += 1
             if model_info.loaded:
                 loaded_models += 1
@@ -695,30 +692,32 @@ class ModelCache(object):
         #multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
        maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
-        current_size = sum([m.size for m in self._cached_infos.values()])
+        current_size = sum([m.size for m in self._cached_models.values()])
 
         if current_size + bytes_needed > maximum_size:
             self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
 
-        self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
 
         pos = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
-            model = self._cache_stack[pos]
-            model_info = self._cached_infos[model]
+            model_key = self._cache_stack[pos]
+            cache_entry = self._cached_models[model_key]
 
-            refs = sys.getrefcount(model)
+            refs = sys.getrefcount(cache_entry.model)
 
-            device = model.device if hasattr(model, "device") else None
-            self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
+            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+            self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
 
-            # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
-            if not model_info.locked and refs <= 3:
-                self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)')
-                current_size -= model_info.size
+            # 2 refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            if not cache_entry.locked and refs <= 2:
+                self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
+                current_size -= cache_entry.size
                 del self._cache_stack[pos]
-                del model
-                del model_info
+                del self._cached_models[model_key]
+                del cache_entry
             else:
                 pos += 1
 
@@ -726,14 +725,14 @@ class ModelCache(object):
         gc.collect()
         torch.cuda.empty_cache()
 
-        self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
 
 
     def _offload_unlocked_models(self):
-        for model, model_info in self._cached_infos.items():
-            if not model_info.locked and model_info.loaded:
-                self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
-                model.to(self.storage_device)
+        for model_key, cache_entry in self._cached_models.items():
+            if not cache_entry.locked and cache_entry.loaded:
+                self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
+                cache_entry.model.to(self.storage_device)
 
     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()
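Below is a minimal, standalone sketch of the ownership model this diff moves to; the names `CacheRecord`, `cached_models`, and `evictable` are illustrative stand-ins, not the InvokeAI API. It demonstrates the new `refs <= 2` eviction test: once the cache record holds the only strong reference to the model, `sys.getrefcount()` sees exactly two references (the record's attribute plus getrefcount's own argument), so any higher count means a caller is still using the model.

```python
import sys
from typing import Any, Dict


class CacheRecord:
    """Illustrative stand-in for _CacheRecord: holds the strong reference to the model."""

    def __init__(self, model: Any, size: int):
        self.model = model
        self.size = size


def evictable(record: CacheRecord) -> bool:
    # 2 refs: one from record.model, one from getrefcount's own argument.
    # A higher count means some caller still holds the model, so keep it cached.
    return sys.getrefcount(record.model) <= 2


cached_models: Dict[str, CacheRecord] = {}
cached_models["model-a"] = CacheRecord(model=object(), size=2 * 1024**3)

print(evictable(cached_models["model-a"]))  # True: only the record references the model

borrowed = cached_models["model-a"].model   # a caller keeps the model alive
print(evictable(cached_models["model-a"]))  # False: refcount is now 3
```

In the diff itself this test lives in the `_make_cache_room` loop, where `_cache_stack` now holds keys rather than model objects, so the stack no longer contributes an extra reference to the model being considered for eviction.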