Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Revert weak references as can be done without it

This commit is contained in:
parent 2533209326
commit 8e419a4f97
@@ -17,7 +17,6 @@ context. Use like this:
 """
 
 import contextlib
-import weakref
 import gc
 import os
 import sys
@@ -428,14 +427,14 @@ class ModelCache(object):
     pass
 
 class _CacheRecord:
-    key: str
     size: int
+    model: Any
    cache: ModelCache
    _locks: int
 
-    def __init__(self, cache, key: Any, size: int):
-        self.key = key
+    def __init__(self, cache, model: Any, size: int):
         self.size = size
+        self.model = model
         self.cache = cache
         self._locks = 0
 
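With the weakrefs reverted, the record itself owns a strong reference to its model. A minimal sketch of the post-revert _CacheRecord, assuming the lock()/unlock()/locked helpers that the later hunks call but this hunk does not show:

from typing import Any

class _CacheRecordSketch:
    """Bookkeeping entry stored in ModelCache._cached_models under the model key."""

    def __init__(self, cache, model: Any, size: int):
        self.model = model      # strong reference; no weakref indirection anymore
        self.size = size        # size in bytes, used for cache accounting
        self.cache = cache
        self._locks = 0         # how many ModelLocker contexts currently pin it

    def lock(self):             # assumed helper: pin the entry while it is in use
        self._locks += 1

    def unlock(self):           # assumed helper: release one pin
        self._locks -= 1
        assert self._locks >= 0

    @property
    def locked(self) -> bool:
        return self._locks > 0

    @property
    def loaded(self) -> bool:
        # Mirrors the property changed in the next hunk: "loaded" means the model
        # reports a device other than the CPU/storage device.
        if self.model is not None and hasattr(self.model, "device"):
            return self.model.device != self.cache.storage_device
        return False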
@@ -452,9 +451,8 @@ class _CacheRecord:
 
     @property
     def loaded(self):
-        model = self.cache._cached_models.get(self.key, None)
-        if model is not None and hasattr(model, "device"):
-            return model.device != self.cache.storage_device
+        if self.model is not None and hasattr(self.model, "device"):
+            return self.model.device != self.cache.storage_device
         else:
             return False
 
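The loaded check only works for objects that expose a device attribute (torch tensors and diffusers pipelines do; a bare nn.Module does not), which is why the hasattr guard remains. A small illustration with a tensor standing in for a cached model:

import torch

storage_device = torch.device("cpu")
weights = torch.zeros(8)                       # stand-in for a cached model object
print(weights.device != storage_device)        # False: still on the storage device

if torch.cuda.is_available():
    weights = weights.to("cuda")
    print(weights.device != storage_device)    # True: now counts as "loaded"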
@@ -493,8 +491,7 @@ class ModelCache(object):
         self.sha_chunksize=sha_chunksize
         self.logger = logger
 
-        self._cached_models = weakref.WeakValueDictionary()
-        self._cached_infos = weakref.WeakKeyDictionary()
+        self._cached_models = dict()
         self._cache_stack = list()
 
     def get_key(
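The two weakref mappings collapse into one ordinary dict because a WeakValueDictionary drops an entry the moment nothing else holds the model, which is exactly the situation a cache is supposed to cover. A short demonstration of that behaviour (CPython, no other references kept):

import weakref

class FakeModel:
    pass

weak_cache = weakref.WeakValueDictionary()
weak_cache["stable-diffusion"] = FakeModel()   # no strong reference survives this line
print("stable-diffusion" in weak_cache)        # False: the entry vanished with its referent

strong_cache = dict()
strong_cache["stable-diffusion"] = FakeModel()
print("stable-diffusion" in strong_cache)      # True: the dict itself keeps the model alive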
@@ -570,8 +567,8 @@ class ModelCache(object):
         )
 
         # TODO: lock for no copies on simultaneous calls?
-        model = self._cached_models.get(key, None)
-        if model is None:
+        cache_entry = self._cached_models.get(key, None)
+        if cache_entry is None:
             self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
 
             # this will remove older cached models until
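The TODO about simultaneous calls is untouched by this revert. One conventional way to address it, shown here only as a hypothetical sketch and not as what the project does, is to serialise the miss path with a threading.Lock so two requests for the same key cannot both trigger a load:

import threading

_load_lock = threading.Lock()   # hypothetical guard, not part of this diff

def get_cached(cache: dict, key: str, loader):
    entry = cache.get(key)
    if entry is None:
        with _load_lock:                 # re-check under the lock to avoid double loads
            entry = cache.get(key)
            if entry is None:
                entry = loader()
                cache[key] = entry
    return entry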
@@ -584,14 +581,14 @@ class ModelCache(object):
             if mem_used := model_info.get_size(submodel):
                 self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
 
-            self._cached_models[key] = model
-            self._cached_infos[model] = _CacheRecord(self, key, mem_used)
+            cache_entry = _CacheRecord(self, model, mem_used)
+            self._cached_models[key] = cache_entry
 
         with suppress(Exception):
-            self._cache_stack.remove(model)
-        self._cache_stack.append(model)
+            self._cache_stack.remove(key)
+        self._cache_stack.append(key)
 
-        return self.ModelLocker(self, key, model, gpu_load)
+        return self.ModelLocker(self, key, cache_entry.model, gpu_load)
 
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
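After the revert the cache is two plain structures: _cached_models maps key -> _CacheRecord, and _cache_stack is a list of keys kept in least-recently-used order (remove followed by append moves a key to the top). A stripped-down sketch of that bookkeeping with made-up names:

from contextlib import suppress

cached_models: dict[str, object] = {}
cache_stack: list[str] = []            # oldest key first, most recently used last

def touch(key: str, entry: object) -> None:
    """Insert or refresh an entry and mark it most recently used."""
    cached_models[key] = entry
    with suppress(ValueError):         # the key may not be on the stack yet
        cache_stack.remove(key)
    cache_stack.append(key)

touch("vae", object())
touch("unet", object())
touch("vae", object())                 # "vae" moves back to the top
print(cache_stack)                     # ['unet', 'vae']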
@@ -604,7 +601,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return self.model
 
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
 
             # NOTE that the model has to have the to() method in order for this
             # code to move it into GPU!
@@ -641,7 +638,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return
 
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
             cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
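Both __enter__ and __exit__ now resolve the record through self.cache._cached_models[self.key] instead of the old model-keyed info dict. The lock/unlock protocol they take part in is roughly the context manager below; a simplified sketch with names assumed from this diff, not the project's exact implementation:

class ModelLockerSketch:
    """Simplified stand-in for ModelCache.ModelLocker."""

    def __init__(self, cache, key, model, gpu_load):
        self.cache = cache
        self.key = key
        self.model = model
        self.gpu_load = gpu_load

    def __enter__(self):
        cache_entry = self.cache._cached_models[self.key]
        cache_entry.lock()                         # pin so the eviction loop skips it
        if self.gpu_load and hasattr(self.model, "to"):
            self.model.to(self.cache.execution_device)
        return self.model

    def __exit__(self, *exc):
        cache_entry = self.cache._cached_models[self.key]
        cache_entry.unlock()
        if not self.cache.lazy_offloading:
            self.cache._offload_unlocked_models()  # push idle models back to the CPU now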
@@ -667,7 +664,7 @@ class ModelCache(object):
 
     def cache_size(self) -> float:
         "Return the current size of the cache, in GB"
-        current_cache_size = sum([m.size for m in self._cached_infos.values()])
+        current_cache_size = sum([m.size for m in self._cached_models.values()])
         return current_cache_size / GIG
 
     def _has_cuda(self) -> bool:
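cache_size() now sums over the records themselves; the arithmetic is simply a byte total divided by GIG (assumed here to be 2**30):

GIG = 2**30   # assumed value of the module-level constant

sizes = [2_132_000_000, 3_438_000_000]   # bytes per cached model
print(f"{sum(sizes) / GIG:.2f} GB")      # 5.19 GB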
@@ -680,7 +677,7 @@ class ModelCache(object):
         cached_models = 0
         loaded_models = 0
         locked_models = 0
-        for model_info in self._cached_infos.values():
+        for model_info in self._cached_models.values():
             cached_models += 1
             if model_info.loaded:
                 loaded_models += 1
@@ -695,30 +692,32 @@ class ModelCache(object):
         #multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
         maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
-        current_size = sum([m.size for m in self._cached_infos.values()])
+        current_size = sum([m.size for m in self._cached_models.values()])
 
         if current_size + bytes_needed > maximum_size:
             self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
 
-            self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")
+            self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
 
             pos = 0
             while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
-                model = self._cache_stack[pos]
-                model_info = self._cached_infos[model]
+                model_key = self._cache_stack[pos]
+                cache_entry = self._cached_models[model_key]
 
-                refs = sys.getrefcount(model)
+                refs = sys.getrefcount(cache_entry.model)
 
-                device = model.device if hasattr(model, "device") else None
-                self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
+                device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+                self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")
 
-                # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
-                if not model_info.locked and refs <= 3:
-                    self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)')
-                    current_size -= model_info.size
+                # 2 refs:
+                # 1 from cache_entry
+                # 1 from getrefcount function
+                if not cache_entry.locked and refs <= 2:
+                    self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
+                    current_size -= cache_entry.size
                     del self._cache_stack[pos]
-                    del model
-                    del model_info
+                    del self._cached_models[model_key]
+                    del cache_entry
 
                 else:
                     pos += 1
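The refcount threshold drops from 3 to 2 because the eviction loop no longer binds the model to a local variable: the only expected references are the one inside the _CacheRecord and the temporary one created by passing the object to sys.getrefcount itself. A quick check of that arithmetic:

import sys

class CacheEntry:
    def __init__(self, model):
        self.model = model

entry = CacheEntry(object())
# 1 ref from entry.model + 1 ref from getrefcount's own argument = 2
print(sys.getrefcount(entry.model))    # 2 -> eligible for eviction

borrowed = entry.model                 # e.g. a caller still holding the model
print(sys.getrefcount(entry.model))    # 3 -> skipped by the eviction loop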
@@ -726,14 +725,14 @@ class ModelCache(object):
             gc.collect()
             torch.cuda.empty_cache()
 
-        self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
 
 
     def _offload_unlocked_models(self):
-        for model, model_info in self._cached_infos.items():
-            if not model_info.locked and model_info.loaded:
-                self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
-                model.to(self.storage_device)
+        for model_key, cache_entry in self._cached_models.items():
+            if not cache_entry.locked and cache_entry.loaded:
+                self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
+                cache_entry.model.to(self.storage_device)
 
     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()
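_local_model_hash, whose first line appears above only as context, is a chunked SHA-256 driven by self.sha_chunksize. A generic sketch of that pattern, independent of this diff and with an assumed default chunk size:

import hashlib
from pathlib import Path
from typing import Union

def sha256_of_file(model_path: Union[str, Path], chunksize: int = 16 * 1024 * 1024) -> str:
    """Hash a checkpoint in fixed-size chunks so it is never read into RAM at once."""
    sha = hashlib.sha256()
    with open(model_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunksize), b""):
            sha.update(chunk)
    return sha.hexdigest()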