Rewrite cache to weak references

This commit is contained in:
  parent 165c1adcf8
  commit 2533209326
@@ -267,10 +267,10 @@ class ModelManagerService(ModelManagerServiceBase):
         logger.debug(f'config file={config_file}')

         device = torch.device(choose_torch_device())
-        if config.precision == "auto":
+        precision = config.precision
+        if precision == "auto":
             precision = choose_precision(device)
-        dtype = torch.float32 if precision=='float32' \
-            else torch.float16
+        dtype = torch.float32 if precision == 'float32' else torch.float16

         # this is transitional backward compatibility
         # support for the deprecated `max_loaded_models`

@@ -17,6 +17,7 @@ context. Use like this:
 """

 import contextlib
+import weakref
 import gc
 import os
 import sys

@@ -427,15 +428,15 @@ class ModelCache(object):
         pass

 class _CacheRecord:
-    model: Any
+    key: str
     size: int
+    cache: ModelCache
     _locks: int
-    _cache: ModelCache

-    def __init__(self, cache, model: Any, size: int):
-        self._cache = cache
-        self.model = model
+    def __init__(self, cache, key: Any, size: int):
+        self.key = key
         self.size = size
+        self.cache = cache
         self._locks = 0

     def lock(self):

@@ -451,8 +452,9 @@ class _CacheRecord:

     @property
     def loaded(self):
-        if self.model is not None and hasattr(self.model, "device"):
-            return self.model.device != self._cache.storage_device
+        model = self.cache._cached_models.get(self.key, None)
+        if model is not None and hasattr(model, "device"):
+            return model.device != self.cache.storage_device
         else:
             return False

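A side note on the rewritten `loaded` property: the record deliberately stops holding the model and instead re-fetches it through the cache's key-to-model mapping, so a `_CacheRecord` can never keep a dead model alive. Below is a minimal standalone sketch of that lookup pattern; the `Model` and `Record` classes and the "sd-1.5" key are made up for illustration and are not InvokeAI's real classes.

import weakref

class Model:
    """Stand-in for a loaded model; starts on the storage ('cpu') device."""
    def __init__(self):
        self.device = "cpu"

class Record:
    """Metadata record that stores only the key, never the model itself."""
    def __init__(self, cache, key):
        self.cache = cache
        self.key = key

    @property
    def loaded(self):
        # Look the model up through the weak mapping; .get() returns None
        # once the model has been garbage collected.
        model = self.cache.get(self.key)
        return model is not None and model.device != "cpu"

cache = weakref.WeakValueDictionary()
model = Model()
cache["sd-1.5"] = model
record = Record(cache, "sd-1.5")

print(record.loaded)   # False: still on the storage device
model.device = "cuda"
print(record.loaded)   # True: appears moved to the execution device
del model              # drop the last strong reference
print(record.loaded)   # False: the weak entry has vanished, so loaded is False

Once the last strong reference to the model goes away, the weak mapping silently loses the entry and `loaded` simply reports False.
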
@@ -478,22 +480,23 @@ class ModelCache(object):
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
-        max_cache_size = 9999
+        #max_cache_size = 9999
         execution_device = torch.device('cuda')

-        self.models: Dict[str, _CacheRecord] = dict()
         self.model_infos: Dict[str, ModelInfoBase] = dict()
-        self.stack: Sequence = list()
         self.lazy_offloading = lazy_offloading
-        self.sequential_offload: bool=sequential_offload
+        #self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
-        self.current_cache_size: int=0
         self.max_cache_size: int=max_cache_size
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
         self.logger = logger

+        self._cached_models = weakref.WeakValueDictionary()
+        self._cached_infos = weakref.WeakKeyDictionary()
+        self._cache_stack = list()

     def get_key(
         self,
         model_path: str,

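For anyone unfamiliar with the two containers that replace `self.models` and `self.stack` here: a `weakref.WeakValueDictionary` drops an entry as soon as nothing else holds the value, and a `weakref.WeakKeyDictionary` drops an entry when its key dies, so per-model bookkeeping disappears together with the model. A small self-contained demonstration of that lifetime follows; the `FakeModel` class and the sizes are made up for the example.

import gc
import weakref

class FakeModel:
    """Illustrative placeholder; any weak-referenceable object behaves the same."""
    pass

cached_models = weakref.WeakValueDictionary()   # key -> model, like _cached_models
cached_infos = weakref.WeakKeyDictionary()      # model -> metadata, like _cached_infos

m = FakeModel()
cached_models["vae"] = m
cached_infos[m] = {"size": 335_000_000, "locks": 0}

print(len(cached_models), len(cached_infos))    # 1 1 while 'm' keeps the model alive

del m            # drop the only strong reference
gc.collect()     # CPython frees it immediately anyway; explicit for clarity

print(len(cached_models), len(cached_infos))    # 0 0: both entries went with the model

This is also why the commit can drop `uncache_model()` and the hand-maintained `current_cache_size`: once nothing holds a model any longer, its cache entry and its metadata evaporate on their own.
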
@@ -546,8 +549,9 @@ class ModelCache(object):
         self,
         repo_id_or_path: Union[str, Path],
         model_type: SDModelType = SDModelType.Diffusers,
-        submodel: SDModelType = None,
-        revision: str = None,
+        submodel: Optional[SDModelType] = None,
+        revision: Optional[str] = None,
+        variant: Optional[str] = None,
         gpu_load: bool = True,
     ) -> Any:

@@ -565,7 +569,9 @@ class ModelCache(object):
             submodel_type=submodel,
         )

-        if key not in self.models:
+        # TODO: lock for no copies on simultaneous calls?
+        model = self._cached_models.get(key, None)
+        if model is None:
             self.logger.info(f'Loading model {repo_id_or_path}, type {model_type}:{submodel}')

             # this will remove older cached models until

@@ -574,45 +580,38 @@ class ModelCache(object):

             # clean memory to make MemoryUsage() more accurate
             gc.collect()
-            model_obj = model_info.get_model(submodel, torch_dtype=self.precision)
+            model = model_info.get_model(submodel, torch_dtype=self.precision)
             if mem_used := model_info.get_size(submodel):
-                logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')
-                self.current_cache_size += mem_used  # increment size of the cache
+                self.logger.debug(f'CPU RAM used for load: {(mem_used/GIG):.2f} GB')

-            self.models[key] = _CacheRecord(self, model_obj, mem_used)
+            self._cached_models[key] = model
+            self._cached_infos[model] = _CacheRecord(self, key, mem_used)

         with suppress(Exception):
-            self.stack.remove(key)
-        self.stack.append(key)
+            self._cache_stack.remove(model)
+        self._cache_stack.append(model)

-        return self.ModelLocker(self, key, self.models[key].model, gpu_load)
+        return self.ModelLocker(self, key, model, gpu_load)

-    def uncache_model(self, key: str):
-        '''Remove corresponding model from the cache'''
-        self.models.pop(key, None)
-        with contextlib.suppress(ValueError):
-            self.stack.remove(key)
-
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load):
             self.gpu_load = gpu_load
             self.cache = cache
             self.key = key
-            # This will keep a copy of the model in RAM until the locker
-            # is garbage collected. Needs testing!
             self.model = model

         def __enter__(self) -> Any:
             if not hasattr(self.model, 'to'):
                 return self.model

-            cache_entry = self.cache.models[self.key]
+            cache_entry = self.cache._cached_infos[self.model]

             # NOTE that the model has to have the to() method in order for this
             # code to move it into GPU!
             if self.gpu_load:
                 cache_entry.lock()

+                try:
                     if self.cache.lazy_offloading:
                         self.cache._offload_unlocked_models()

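One detail that is easy to miss in the hunk above: `_cached_models` holds only weak references, while the plain list `_cache_stack` (and the `self.model` kept by the returned `ModelLocker`) hold ordinary strong references, so a model stays cached until `_make_cache_room()` drops its stack slot. A hedged sketch of that interplay follows, with toy names (`FakeModel`, `get_model`, `cache_stack`) standing in for the real classes.

import weakref

class FakeModel:
    """Toy stand-in for a loaded model."""
    pass

cached_models = weakref.WeakValueDictionary()   # weak: never keeps a model alive
cache_stack = []                                # strong: this is what pins cached models

def get_model(key):
    # Reuse the live object if the weak dict still has it, otherwise "load" it.
    model = cached_models.get(key)
    if model is None:
        model = FakeModel()
        cached_models[key] = model
    cache_stack.append(model)
    return model

m = get_model("unet")
del m                                 # the caller lets go...
assert "unet" in cached_models        # ...but the stack still pins the model

cache_stack.clear()                   # roughly what eviction amounts to
assert "unet" not in cached_models    # now the weak entry disappears too
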
@@ -625,6 +624,11 @@ class ModelCache(object):
                     self.cache.logger.debug(f'Locking {self.key} in {self.cache.execution_device}')
                     self.cache._print_cuda_stats()

+                except:
+                    cache_entry.unlock()
+                    raise
+
             # TODO: not fully understand
             # in the event that the caller wants the model in RAM, we
             # move it into CPU if it is in GPU and not locked

@@ -637,11 +641,13 @@ class ModelCache(object):
             if not hasattr(self.model, 'to'):
                 return

-            self.cache.models[self.key].unlock()
+            cache_entry = self.cache._cached_infos[self.model]
+            cache_entry.unlock()
             if not self.cache.lazy_offloading:
                 self.cache._offload_unlocked_models()
                 self.cache._print_cuda_stats()


     def model_hash(
         self,
         repo_id_or_path: Union[str, Path],

@@ -661,55 +667,73 @@ class ModelCache(object):

     def cache_size(self) -> float:
         "Return the current size of the cache, in GB"
-        return self.current_cache_size / GIG
+        current_cache_size = sum([m.size for m in self._cached_infos.values()])
+        return current_cache_size / GIG

     def _has_cuda(self) -> bool:
         return self.execution_device.type == 'cuda'

     def _print_cuda_stats(self):
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
-        ram = "%4.2fG" % (self.current_cache_size / GIG)
+        ram = "%4.2fG" % self.cache_size()

+        cached_models = 0
         loaded_models = 0
         locked_models = 0
-        for cache_entry in self.models.values():
-            if cache_entry.loaded:
+        for model_info in self._cached_infos.values():
+            cached_models += 1
+            if model_info.loaded:
                 loaded_models += 1
-            if cache_entry.locked:
+            if model_info.locked:
                 locked_models += 1

-        logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
+        self.logger.debug(f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}")


     def _make_cache_room(self, model_size):
         # calculate how much memory this model will require
         #multiplier = 2 if self.precision==torch.float32 else 1
         bytes_needed = model_size
         maximum_size = self.max_cache_size * GIG  # stored in GB, convert to bytes
-        current_size = self.current_cache_size
+        current_size = sum([m.size for m in self._cached_infos.values()])

         if current_size + bytes_needed > maximum_size:
-            logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')
+            self.logger.debug(f'Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB')

+        self.logger.debug(f"Before unloading: cached_models={len(self._cached_infos)}")

         pos = 0
-        while current_size + bytes_needed > maximum_size and current_size > 0 and len(self.stack) > 0 and pos < len(self.stack):
-            model_key = self.stack[pos]
-            cache_entry = self.models[model_key]
-            if not cache_entry.locked:
-                logger.debug(f'Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)')
-                self.uncache_model(model_key) # del self.stack[pos]
-                current_size -= cache_entry.size
+        while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
+            model = self._cache_stack[pos]
+            model_info = self._cached_infos[model]
+
+            refs = sys.getrefcount(model)
+            device = model.device if hasattr(model, "device") else None
+            self.logger.debug(f"Model: {model_info.key}, locks: {model_info._locks}, device: {device}, loaded: {model_info.loaded}, refs: {refs}")
+
+            # 3 refs = 1 from _cache_stack, 1 from local variable, 1 from getrefcount function
+            if not model_info.locked and refs <= 3:
+                self.logger.debug(f'Unloading model {model_info.key} to free {(model_size/GIG):.2f} GB (-{(model_info.size/GIG):.2f} GB)')
+                current_size -= model_info.size
+                del self._cache_stack[pos]
+                del model
+                del model_info

             else:
                 pos += 1

-        self.current_cache_size = current_size
         gc.collect()
+        torch.cuda.empty_cache()

+        self.logger.debug(f"After unloading: cached_models={len(self._cached_infos)}")


     def _offload_unlocked_models(self):
-        for key in self.models.keys():
-            cache_entry = self.models[key]
-            if not cache_entry.locked and cache_entry.loaded:
-                self.logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
-                cache_entry.model.to(self.storage_device)
+        for model, model_info in self._cached_infos.items():
+            if not model_info.locked and model_info.loaded:
+                self.logger.debug(f'Offloading {model_info.key} from {self.execution_device} into {self.storage_device}')
+                model.to(self.storage_device)

     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()

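The eviction loop above decides whether a model is still in use by counting references: when the only remaining strong references are the cache's own (`_cache_stack[pos]`, the loop-local `model` variable, plus the temporary argument that `sys.getrefcount()` itself creates), the count is 3 and the model can be unloaded safely, hence the `refs <= 3` test. A small CPython-specific sketch of how those counts come about follows; the names are illustrative only.

import sys

class FakeModel:
    """Toy object standing in for a cached model."""
    pass

cache_stack = []

model = FakeModel()
cache_stack.append(model)

# Two strong references exist (cache_stack[0] and the local 'model');
# getrefcount's own argument adds a third, so this prints 3.
print(sys.getrefcount(model))

borrowed = model                 # a caller still holding the model -> prints 4
print(sys.getrefcount(model))

del borrowed                     # back to 3: only cache bookkeeping remains,
print(sys.getrefcount(model))    # so the 'refs <= 3' test would allow eviction

This heuristic relies on CPython's reference counting; the weak dictionaries then do the actual cleanup once `del self._cache_stack[pos]` and `del model` remove those last strong references.
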
@@ -721,7 +745,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash

-        logger.debug(f'computing hash of model {path.name}')
+        self.logger.debug(f'computing hash of model {path.name}')
         for file in list(path.rglob("*.ckpt")) \
             + list(path.rglob("*.safetensors")) \
             + list(path.rglob("*.pth")):

@@ -24,7 +24,7 @@ class CodeFormerRestoration:
         self.codeformer_model_exists = self.model_path.exists()

         if not self.codeformer_model_exists:
-            logger.error("NOT FOUND: CodeFormer model not found at " + self.model_path)
+            logger.error(f"NOT FOUND: CodeFormer model not found at {self.model_path}")
         sys.path.append(os.path.abspath(codeformer_dir))

     def process(self, image, strength, device, seed=None, fidelity=0.75):

@@ -18,7 +18,7 @@ class GFPGAN:
         self.gfpgan_model_exists = os.path.isfile(self.model_path)

         if not self.gfpgan_model_exists:
-            logger.error("NOT FOUND: GFPGAN model not found at " + self.model_path)
+            logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")
             return None

     def model_exists(self):