From 8e419a4f9736a119e4b7d81b10870772be4319be Mon Sep 17 00:00:00 2001
From: Sergey Borisov <stalkek7779@yandex.ru>
Date: Tue, 23 May 2023 04:29:40 +0300
Subject: [PATCH] Revert weak references as can be done without it

---
 .../backend/model_management/model_cache.py   | 75 +++++++++----------
 1 file changed, 37 insertions(+), 38 deletions(-)

diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 714efb2b28..5811223650 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -17,7 +17,6 @@ context. Use like this:
 """
 
 import contextlib
-import weakref
 import gc
 import os
 import sys
@@ -428,14 +427,14 @@ class ModelCache(object):
     pass
 
 class _CacheRecord:
-    key: str
     size: int
+    model: Any
     cache: ModelCache
     _locks: int
 
-    def __init__(self, cache, key: Any, size: int):
-        self.key = key
+    def __init__(self, cache, model: Any, size: int):
         self.size = size
+        self.model = model
         self.cache = cache
         self._locks = 0
 
@@ -452,9 +451,8 @@ class _CacheRecord:
 
     @property
     def loaded(self):
-        model = self.cache._cached_models.get(self.key, None)
-        if model is not None and hasattr(model, "device"):
-            return model.device != self.cache.storage_device
+        if self.model is not None and hasattr(self.model, "device"):
+            return self.model.device != self.cache.storage_device
         else:
             return False
     
@@ -493,8 +491,7 @@ class ModelCache(object):
         self.sha_chunksize=sha_chunksize
         self.logger = logger
 
-        self._cached_models = weakref.WeakValueDictionary()
-        self._cached_infos = weakref.WeakKeyDictionary()
+        self._cached_models = dict()
         self._cache_stack = list()
 
     def get_key(
@@ -570,8 +567,8 @@ class ModelCache(object):
         )
 
         # TODO: lock for no copies on simultaneous calls?
-        model = self._cached_models.get(key, None)
-        if model is None:
+        cache_entry = self._cached_models.get(key, None)
+        if cache_entry is None:
             self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
 
             # this will remove older cached models until
@@ -584,14 +581,14 @@ class ModelCache(object):
             if mem_used := model_info.get_size(submodel):
                 self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
 
-            self._cached_models[key] = model
-            self._cached_infos[model] = _CacheRecord(self, key, mem_used)
+            cache_entry = _CacheRecord(self, model, mem_used)
+            self._cached_models[key] = cache_entry
 
         with suppress(Exception):
-            self._cache_stack.remove(model)
-        self._cache_stack.append(model)
+            self._cache_stack.remove(key)
+        self._cache_stack.append(key)
 
-        return self.ModelLocker(self, key, model, gpu_load)
+        return self.ModelLocker(self, key, cache_entry.model, gpu_load)
 
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
@@ -604,7 +601,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return self.model
 
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
 
             # NOTE that the model has to have the to() method in order for this
             # code to move it into GPU!
@@ -641,7 +638,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return
 
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
             cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
@@ -667,7 +664,7 @@ class ModelCache(object):
 
     def cache_size(self) -> float:
         "Return the current size of the cache, in GB"
-        current_cache_size = sum([m.size for m in self._cached_infos.values()])
+        current_cache_size = sum([m.size for m in self._cached_models.values()])
         return current_cache_size / GIG
 
     def _has_cuda(self) -> bool:
@@ -680,7 +677,7 @@ class ModelCache(object):
         cached_models = 0
         loaded_models = 0
         locked_models = 0
-        for model_info in self._cached_infos.values():
+        for model_info in self._cached_models.values():
             cached_models += 1
             if model_info.loaded:
                 loaded_models += 1
@@ -695,30 +692,32 @@ class ModelCache(object):
         #multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
         maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
-        current_size = sum([m.size for m in self._cached_infos.values()])
+        current_size = sum([m.size for m in self._cached_models.values()])
 
         if current_size + bytes_needed > maximum_size:
             self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
 
-        self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
 
         pos = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
-            model = self._cache_stack[pos]
-            model_info = self._cached_infos[model]
+            model_key = self._cache_stack[pos]
+            cache_entry = self._cached_models[model_key]
 
-            refs = sys.getrefcount(model)
+            refs = sys.getrefcount(cache_entry.model)
 
-            device = model.device if hasattr(model, "device") else None
-            self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
+            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+            self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
 
-            # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
-            if not model_info.locked and refs <= 3:
-                self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)')
-                current_size -= model_info.size
+            # 2 refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            if not cache_entry.locked and refs <= 2:
+                self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
+                current_size -= cache_entry.size
                 del self._cache_stack[pos]
-                del model
-                del model_info
+                del self._cached_models[model_key]
+                del cache_entry
 
             else:
                 pos += 1
@@ -726,14 +725,14 @@ class ModelCache(object):
         gc.collect()
         torch.cuda.empty_cache()
 
-        self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
 
 
     def _offload_unlocked_models(self):
-        for model, model_info in self._cached_infos.items():
-            if not model_info.locked and model_info.loaded:
-                self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
-                model.to(self.storage_device)
+        for model_key, cache_entry in self._cached_models.items():
+            if not cache_entry.locked and cache_entry.loaded:
+                self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
+                cache_entry.model.to(self.storage_device)
         
     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()