mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main and resolve conflicts
This commit is contained in:
@ -42,10 +42,26 @@ T = TypeVar("T")
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""Elements of the cache."""
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
"""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
@ -20,7 +20,6 @@ context. Use like this:
|
||||
|
||||
import gc
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
@ -163,7 +162,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
@ -253,21 +254,40 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
May raise a torch.cuda.OutOfMemoryError
|
||||
"""
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
model = cache_entry.model
|
||||
source_device = cache_entry.device
|
||||
|
||||
source_device = model.device if hasattr(model, "device") else self.storage_device
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
if not hasattr(cache_entry.model, "to"):
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if hasattr(model, "to"):
|
||||
model.to(target_device)
|
||||
elif isinstance(model, dict):
|
||||
for _, v in model.items():
|
||||
if hasattr(v, "to"):
|
||||
v.to(target_device)
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
@ -347,43 +367,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
# break
|
||||
|
||||
# repeat if referrers changes(due to frame clear), else exit loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
||||
f" refs: {refs}"
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
)
|
||||
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
if not cache_entry.locked:
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
|
Reference in New Issue
Block a user