Revert weak references, as the cache can be implemented without them

Author: Sergey Borisov, 2023-05-23 04:29:40 +03:00
parent 2533209326
commit 8e419a4f97


@@ -17,7 +17,6 @@ context. Use like this:
 """
 import contextlib
-import weakref
 import gc
 import os
 import sys
@@ -428,14 +427,14 @@ class ModelCache(object):
     pass

 class _CacheRecord:
-    key: str
     size: int
+    model: Any
     cache: ModelCache
     _locks: int

-    def __init__(self, cache, key: Any, size: int):
-        self.key = key
+    def __init__(self, cache, model: Any, size: int):
         self.size = size
+        self.model = model
         self.cache = cache
         self._locks = 0
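
After this change _CacheRecord holds a strong reference to the model it describes, so lock counting and size bookkeeping no longer depend on a separate weak-keyed table keeping the model alive. Below is a minimal sketch of that behaviour; CacheRecordSketch is a hypothetical stand-in, and the lock()/unlock()/locked members are assumed from how the later hunks use them, not copied from the commit.

from typing import Any

class CacheRecordSketch:
    """Simplified stand-in for _CacheRecord: owns the model directly."""

    def __init__(self, cache: Any, model: Any, size: int):
        self.size = size
        self.model = model   # strong reference keeps the cached model alive
        self.cache = cache
        self._locks = 0

    def lock(self):
        self._locks += 1

    def unlock(self):
        self._locks -= 1
        assert self._locks >= 0

    @property
    def locked(self) -> bool:
        return self._locks > 0

entry = CacheRecordSketch(cache=None, model=object(), size=2 * 1024**3)
entry.lock()
assert entry.locked          # locked entries are skipped by eviction/offload
entry.unlock()
assert not entry.locked
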
@@ -452,9 +451,8 @@ class _CacheRecord:
     @property
     def loaded(self):
-        model = self.cache._cached_models.get(self.key, None)
-        if model is not None and hasattr(model, "device"):
-            return model.device != self.cache.storage_device
+        if self.model is not None and hasattr(self.model, "device"):
+            return self.model.device != self.cache.storage_device
         else:
             return False
@@ -493,8 +491,7 @@ class ModelCache(object):
         self.sha_chunksize=sha_chunksize
         self.logger = logger
-        self._cached_models = weakref.WeakValueDictionary()
-        self._cached_infos = weakref.WeakKeyDictionary()
+        self._cached_models = dict()
         self._cache_stack = list()

     def get_key(
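
The heart of the revert is here: _cached_models goes from a weakref.WeakValueDictionary, whose entries silently disappear as soon as nothing else holds the model, back to a plain dict that keeps models alive until the cache evicts them itself, and the separate _cached_infos weak-keyed table is dropped. A small self-contained demonstration of that difference (toy Model class, not the cache itself):

import gc
import weakref

class Model:
    """Toy stand-in; WeakValueDictionary values must be weak-referenceable."""
    pass

weak_cache = weakref.WeakValueDictionary()
plain_cache = dict()

m = Model()
weak_cache["key"] = m
plain_cache["key"] = m

del m          # drop the only outside reference to the model
gc.collect()   # make collection deterministic for the demo

print("key" in weak_cache)    # False: the weak entry vanished with the model
print("key" in plain_cache)   # True: the plain dict still owns the model
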
@@ -570,8 +567,8 @@ class ModelCache(object):
         )
         # TODO: lock for no copies on simultaneous calls?
-        model = self._cached_models.get(key, None)
-        if model is None:
+        cache_entry = self._cached_models.get(key, None)
+        if cache_entry is None:
             self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')
             # this will remove older cached models until
@@ -584,14 +581,14 @@ class ModelCache(object):
             if mem_used := model_info.get_size(submodel):
                 self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
-            self._cached_models[key] = model
-            self._cached_infos[model] = _CacheRecord(self, key, mem_used)
+            cache_entry = _CacheRecord(self, model, mem_used)
+            self._cached_models[key] = cache_entry

         with suppress(Exception):
-            self._cache_stack.remove(model)
-        self._cache_stack.append(model)
+            self._cache_stack.remove(key)
+        self._cache_stack.append(key)

-        return self.ModelLocker(self, key, model, gpu_load)
+        return self.ModelLocker(self, key, cache_entry.model, gpu_load)

     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
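
get_model now stores one _CacheRecord per key and pushes the key, rather than the model object, onto _cache_stack, so the stack acts as an LRU order: every access moves the key to the end, and _make_cache_room later scans from position 0, the least recently used side. A compact sketch of that key-based bookkeeping (the touch helper is hypothetical, introduced here only for illustration):

from contextlib import suppress

cache_stack: list[str] = []     # least recently used keys at the front

def touch(key: str) -> None:
    # Mirrors the diff's _cache_stack.remove(key) / append(key) pair;
    # suppress() covers the case where the key is not in the stack yet.
    with suppress(ValueError):
        cache_stack.remove(key)
    cache_stack.append(key)

touch("vae")
touch("unet")
touch("vae")                    # re-access moves 'vae' back to the MRU end
print(cache_stack)              # ['unet', 'vae'] -> 'unet' is evicted first
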
@@ -604,7 +601,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return self.model
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
             # NOTE that the model has to have the to() method in order for this
             # code to move it into GPU!
@@ -641,7 +638,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return
-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
             cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
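
Both the enter and exit paths of the nested ModelLocker now fetch the record by key from _cached_models instead of by model object from the removed _cached_infos. A bare-bones context-manager sketch of that lock-on-enter / unlock-on-exit pairing (toy classes, no device movement; the lock() call on enter is an assumption, since these hunks only show the unlock side):

class _RecordSketch:
    def __init__(self, model):
        self.model = model
        self._locks = 0

    def lock(self):
        self._locks += 1

    def unlock(self):
        self._locks -= 1

class LockerSketch:
    """Simplified stand-in for ModelCache.ModelLocker."""

    def __init__(self, cached_models: dict, key: str):
        self.cached_models = cached_models
        self.key = key

    def __enter__(self):
        entry = self.cached_models[self.key]   # lookup by key, as in the new code
        entry.lock()
        return entry.model

    def __exit__(self, *exc):
        self.cached_models[self.key].unlock()
        return False

models = {"unet": _RecordSketch(model=object())}
with LockerSketch(models, "unet") as model:
    assert models["unet"]._locks == 1          # locked while in use
assert models["unet"]._locks == 0              # unlocked on exit
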
@@ -667,7 +664,7 @@ class ModelCache(object):
     def cache_size(self) -> float:
         "Return the current size of the cache, in GB"
-        current_cache_size = sum([m.size for m in self._cached_infos.values()])
+        current_cache_size = sum([m.size for m in self._cached_models.values()])
         return current_cache_size / GIG

     def _has_cuda(self) -> bool:
@@ -680,7 +677,7 @@ class ModelCache(object):
         cached_models = 0
         loaded_models = 0
         locked_models = 0
-        for model_info in self._cached_infos.values():
+        for model_info in self._cached_models.values():
             cached_models += 1
             if model_info.loaded:
                 loaded_models += 1
@@ -695,30 +692,32 @@ class ModelCache(object):
         #multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
         maximum_size = self.max_cache_size * GIG # stored in GB, convert to bytes
-        current_size = sum([m.size for m in self._cached_infos.values()])
+        current_size = sum([m.size for m in self._cached_models.values()])
         if current_size + bytes_needed > maximum_size:
             self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
-        self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
         pos = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
-            model = self._cache_stack[pos]
-            model_info = self._cached_infos[model]
+            model_key = self._cache_stack[pos]
+            cache_entry = self._cached_models[model_key]

-            refs = sys.getrefcount(model)
+            refs = sys.getrefcount(cache_entry.model)

-            device = model.device if hasattr(model, "device") else None
-            self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
+            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+            self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")

-            # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
-            if not model_info.locked and refs <= 3:
-                self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)')
-                current_size -= model_info.size
+            # 2 refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            if not cache_entry.locked and refs <= 2:
+                self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
+                current_size -= cache_entry.size
                 del self._cache_stack[pos]
-                del model
-                del model_info
+                del self._cached_models[model_key]
+                del cache_entry
             else:
                 pos += 1
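
The reference-count threshold drops from 3 to 2 because the bookkeeping changed: previously the model object itself sat in _cache_stack and was bound to a loop local, so an evictable model showed three references (stack slot, local variable, and getrefcount's own argument); now only the _CacheRecord attribute and the getrefcount argument hold it. A small CPython-specific demonstration of how those counts arise (the exact numbers assume nothing else references the object):

import sys

class Model:
    pass

# New scheme: only the cache record owns the model.
class Record:
    def __init__(self, model):
        self.model = model

entry = Record(Model())
# 2 = the Record attribute + the temporary reference held by the getrefcount call
print(sys.getrefcount(entry.model))   # 2

# Old scheme: the LRU stack stored the model object and the loop bound a local to it.
stack = [Model()]
model = stack[0]
# 3 = stack slot + local variable + the getrefcount argument (the removed `refs <= 3`)
print(sys.getrefcount(model))         # 3
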
@@ -726,14 +725,14 @@ class ModelCache(object):
         gc.collect()
         torch.cuda.empty_cache()
-        self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

     def _offload_unlocked_models(self):
-        for model, model_info in self._cached_infos.items():
-            if not model_info.locked and model_info.loaded:
-                self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
-                model.to(self.storage_device)
+        for model_key, cache_entry in self._cached_models.items():
+            if not cache_entry.locked and cache_entry.loaded:
+                self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
+                cache_entry.model.to(self.storage_device)

     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()
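
The trailing context (sha_chunksize in __init__ and _local_model_hash here) suggests chunked hashing: local model files are fed to SHA-256 in fixed-size pieces so large checkpoints are never read into memory at once. A generic sketch of that pattern under that assumption, for a single file rather than the repository's exact implementation:

import hashlib
from pathlib import Path
from typing import Union

def chunked_sha256(path: Union[str, Path], chunksize: int = 16 * 1024**2) -> str:
    """Hash one file in chunksize-byte pieces (16 MiB is an arbitrary default)."""
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunksize):
            sha.update(chunk)
    return sha.hexdigest()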