Revert weak references as can be done without it

Sergey Borisov 2023-05-23 04:29:40 +03:00
parent 2533209326
commit 8e419a4f97


@@ -17,7 +17,6 @@ context. Use like this:
 """
 import contextlib
-import weakref
 import gc
 import os
 import sys
@@ -428,14 +427,14 @@ class ModelCache(object):
         pass

 class _CacheRecord:
-    key: str
     size: int
+    model: Any
     cache: ModelCache
     _locks: int

-    def __init__(self, cache, key: Any, size: int):
-        self.key = key
+    def __init__(self, cache, model: Any, size: int):
         self.size = size
+        self.model = model
         self.cache = cache
         self._locks = 0
@@ -452,9 +451,8 @@ class _CacheRecord:
     @property
     def loaded(self):
-        model = self.cache._cached_models.get(self.key, None)
-        if model is not None and hasattr(model, "device"):
-            return model.device != self.cache.storage_device
+        if self.model is not None and hasattr(self.model, "device"):
+            return self.model.device != self.cache.storage_device
         else:
             return False
@@ -493,8 +491,7 @@ class ModelCache(object):
         self.sha_chunksize=sha_chunksize
         self.logger = logger

-        self._cached_models = weakref.WeakValueDictionary()
-        self._cached_infos = weakref.WeakKeyDictionary()
+        self._cached_models = dict()
         self._cache_stack = list()

     def get_key(
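The core of the revert is this swap: a plain dict keeps every cached model alive until the cache explicitly deletes the entry, whereas the weak-reference dictionaries let entries vanish as soon as the last outside reference died. A minimal CPython sketch of that difference (toy Model class, not the real ModelCache types):

import weakref

class Model:
    pass

m1, m2 = Model(), Model()
plain = {"plain-key": m1}
weak = weakref.WeakValueDictionary({"weak-key": m2})

del m1, m2  # drop the external references

print("plain-key" in plain)  # True: the dict keeps the model alive, so eviction must be explicit
print("weak-key" in weak)    # False: the weak entry disappeared with its object

This is why the eviction loop further down now deletes entries itself (del self._cached_models[model_key]) instead of relying on garbage collection to empty the mapping.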
@@ -570,8 +567,8 @@ class ModelCache(object):
         )

         # TODO: lock for no copies on simultaneous calls?
-        model = self._cached_models.get(key, None)
-        if model is None:
+        cache_entry = self._cached_models.get(key, None)
+        if cache_entry is None:
             self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')

             # this will remove older cached models until
@@ -584,14 +581,14 @@ class ModelCache(object):
             if mem_used := model_info.get_size(submodel):
                 self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

-            self._cached_models[key] = model
-            self._cached_infos[model] = _CacheRecord(self, key, mem_used)
+            cache_entry = _CacheRecord(self, model, mem_used)
+            self._cached_models[key] = cache_entry

         with suppress(Exception):
-            self._cache_stack.remove(model)
-        self._cache_stack.append(model)
+            self._cache_stack.remove(key)
+        self._cache_stack.append(key)

-        return self.ModelLocker(self, key, model, gpu_load)
+        return self.ModelLocker(self, key, cache_entry.model, gpu_load)

     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
@@ -604,7 +601,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return self.model

-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]

             # NOTE that the model has to have the to() method in order for this
             # code to move it into GPU!
@@ -641,7 +638,7 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return

-            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry = self.cache._cached_models[self.key]
             cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
@@ -667,7 +664,7 @@ class ModelCache(object):
     def cache_size(self) -> float:
         "Return the current size of the cache, in GB"
-        current_cache_size = sum([m.size for m in self._cached_infos.values()])
+        current_cache_size = sum([m.size for m in self._cached_models.values()])
         return current_cache_size / GIG

     def _has_cuda(self) -> bool:
@@ -680,7 +677,7 @@ class ModelCache(object):
         cached_models = 0
         loaded_models = 0
         locked_models = 0
-        for model_info in self._cached_infos.values():
+        for model_info in self._cached_models.values():
             cached_models += 1
             if model_info.loaded:
                 loaded_models += 1
@@ -695,30 +692,32 @@ class ModelCache(object):
         #multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
         maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
-        current_size = sum([m.size for m in self._cached_infos.values()])
+        current_size = sum([m.size for m in self._cached_models.values()])

         if current_size + bytes_needed > maximum_size:
             self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')

-        self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

         pos = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
-            model = self._cache_stack[pos]
-            model_info = self._cached_infos[model]
+            model_key = self._cache_stack[pos]
+            cache_entry = self._cached_models[model_key]

-            refs = sys.getrefcount(model)
+            refs = sys.getrefcount(cache_entry.model)

-            device = model.device if hasattr(model, "device") else None
-            self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
+            device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
+            self.logger.debug(f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}")

-            # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
-            if not model_info.locked and refs <= 3:
-                self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)')
-                current_size -= model_info.size
+            # 2 refs:
+            # 1 from cache_entry
+            # 1 from getrefcount function
+            if not cache_entry.locked and refs <= 2:
+                self.logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
+                current_size -= cache_entry.size
                 del self._cache_stack[pos]
-                del model
-                del model_info
+                del self._cached_models[model_key]
+                del cache_entry

             else:
                 pos += 1
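The eviction threshold drops from refs <= 3 to refs <= 2 because the stack now holds keys instead of models, so an idle model is referenced only by its _CacheRecord plus the temporary reference sys.getrefcount creates for its argument. A small CPython-specific sketch of that arithmetic (toy classes, not the real cache types):

import sys

class _Entry:
    """Stand-in for _CacheRecord: the entry alone owns the model."""
    def __init__(self, model):
        self.model = model

class _Model:
    pass

entry = _Entry(_Model())
print(sys.getrefcount(entry.model))  # 2: entry.model + getrefcount's argument -> safe to evict

locker = entry.model                 # e.g. a ModelLocker still holding the model
print(sys.getrefcount(entry.model))  # 3: an outside reference exists -> skip eviction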
@@ -726,14 +725,14 @@ class ModelCache(object):
         gc.collect()
         torch.cuda.empty_cache()

-        self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")
+        self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")

     def _offload_unlocked_models(self):
-        for model, model_info in self._cached_infos.items():
-            if not model_info.locked and model_info.loaded:
-                self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
-                model.to(self.storage_device)
+        for model_key, cache_entry in self._cached_models.items():
+            if not cache_entry.locked and cache_entry.loaded:
+                self.logger.debug(f'Offloading {model_key} from {self.execution_device} into {self.storage_device}')
+                cache_entry.model.to(self.storage_device)

     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()
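Taken together, the revert replaces the WeakValueDictionary/WeakKeyDictionary pair with a single dict of _CacheRecord entries keyed by model key, an LRU-style stack of keys, and purely size-driven eviction. A condensed, hypothetical sketch of that shape (simplified names, no locking or device handling, not the real ModelCache API):

from typing import Any

class CacheRecordSketch:
    """Simplified stand-in for _CacheRecord: owns the model and its size."""
    def __init__(self, model: Any, size: int):
        self.model = model
        self.size = size
        self.locks = 0

class TinyModelCache:
    """Toy illustration: plain dict of records plus an LRU-ish key stack,
    evicted purely by size accounting, no weak references."""
    def __init__(self, max_size: int):
        self.max_size = max_size
        self._cached_models: dict[str, CacheRecordSketch] = {}
        self._cache_stack: list[str] = []

    def put(self, key: str, model: Any, size: int):
        self._make_room(size)
        self._cached_models[key] = CacheRecordSketch(model, size)
        if key in self._cache_stack:
            self._cache_stack.remove(key)
        self._cache_stack.append(key)  # most recently used at the end

    def _make_room(self, needed: int):
        current = sum(r.size for r in self._cached_models.values())
        pos = 0
        while current + needed > self.max_size and pos < len(self._cache_stack):
            key = self._cache_stack[pos]
            record = self._cached_models[key]
            if record.locks == 0:
                current -= record.size
                del self._cache_stack[pos]
                del self._cached_models[key]  # dropping the entry drops the model
            else:
                pos += 1

cache = TinyModelCache(max_size=10)
cache.put("unet", object(), size=6)
cache.put("vae", object(), size=6)   # evicts "unet" to stay under the limit
print(list(cache._cached_models))    # ['vae']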