Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
copy model from a meta device template
- temporarily disable vram cache
commit 2219e3643a (parent 6932f27b43)
@@ -52,11 +52,10 @@ class CacheRecord(Generic[T]):
     Elements of the cache:
 
     key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
+    model: Read-only copy of the model *without weights* residing in the "meta device"
     state_dict: A read-only copy of the model's state dict in RAM. It will be
        used as a template for creating a copy in the VRAM.
     size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
 
     Before a model is executed, the state_dict template is copied into VRAM,
     and then injected into the model. When the model is finished, the VRAM
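To make the scheme described in the docstring concrete, here is a minimal, self-contained sketch; it is not InvokeAI code, and the toy Linear model and the materialize() helper are illustrative assumptions. It keeps a weight-free template on the meta device plus a CPU master copy of the state dict, then builds a usable copy on the execution device on demand.

import torch

cpu_model = torch.nn.Linear(4, 4)        # model as loaded into RAM
state_dict = cpu_model.state_dict()      # master copy of the weights (detached CPU tensors)
template = cpu_model.to(device="meta")   # weight-free template on the meta device

def materialize(template: torch.nn.Module, state_dict: dict, device: torch.device) -> torch.nn.Module:
    # to_empty() allocates uninitialized parameter storage on the target device (in place);
    # load_state_dict() then fills that storage from the CPU master copy.
    working = template.to_empty(device=device)
    working.load_state_dict(state_dict)
    return working

gpu_copy = materialize(template, state_dict, torch.device("cuda" if torch.cuda.is_available() else "cpu"))

Note that to_empty() materializes the template in place; the commit instead re-instantiates the model class from its config (see model_to_device() below), which leaves the meta template untouched.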
@@ -72,25 +71,7 @@ class CacheRecord(Generic[T]):
     key: str
     size: int
     model: T
-    device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
-    size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
-
 
 @dataclass
@@ -36,6 +36,7 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
+from ..optimizations import skip_torch_weight_init
 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker
 
@@ -221,8 +222,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(model)
         self.make_room(size)
 
-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        if isinstance(model, torch.nn.Module):
+            state_dict = model.state_dict()  # keep a master copy of the state dict
+            model = model.to(device="meta")  # and keep a template in the meta device
+        else:
+            state_dict = None
+        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
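A hedged illustration of the ordering that the new put() logic relies on (toy model, not InvokeAI code): the state dict must be captured before the model is moved to the meta device, because meta tensors carry no data, while the previously captured state dict keeps pointing at the real CPU storage.

import torch

model = torch.nn.Linear(8, 8)
state_dict = model.state_dict()    # real CPU tensors (detached views of the weights)
model = model.to(device="meta")    # parameters are replaced by data-free meta tensors

assert next(model.parameters()).is_meta
assert not state_dict["weight"].is_meta   # the master copy still holds real data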
@@ -284,48 +289,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
         else:
             return model_key
 
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Move any unused models from VRAM."""
-        device = self.get_execution_device()
-        reserved = self._max_vram_cache_size * GIG
-        vram_in_use = torch.cuda.memory_allocated(device) + size_required
-        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
-        for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
-            if vram_in_use <= reserved:
-                break
-            if not cache_entry.loaded:
-                continue
-            if cache_entry.device is not device:
-                continue
-            if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
-                vram_in_use = torch.cuda.memory_allocated() + size_required
-                self.logger.debug(
-                    f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
-                )
-
-        TorchDevice.empty_cache()
-
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device.
+    def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+        """Move a copy of the model into the indicated device and return it.
 
         :param cache_entry: The CacheRecord for the model
         :param target_device: The torch.device to move the model into
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
-        source_device = cache_entry.device
-
-        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
-        # This would need to be revised to support multi-GPU.
-        if torch.device(source_device).type == torch.device(target_device).type:
-            return
-
-        # Some models don't have a `to` method, in which case they run in RAM/CPU.
-        if not hasattr(cache_entry.model, "to"):
-            return
+        self.logger.info(f"Called to move {cache_entry.key} to {target_device}")
+
+        # Some models don't have a state dictionary, in which case the
+        # stored model will still reside in CPU
+        if cache_entry.state_dict is None:
+            return cache_entry.model
 
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
@@ -338,27 +315,25 @@ class ModelCache(ModelCacheBase[AnyModel]):
         # in RAM into the model. So this operation is very fast.
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
 
         try:
-            if cache_entry.state_dict is not None:
-                assert hasattr(cache_entry.model, "load_state_dict")
-                if target_device == self.storage_device:
-                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
-                else:
-                    new_dict: Dict[str, torch.Tensor] = {}
-                    for k, v in cache_entry.state_dict.items():
-                        new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
-                    cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=True)
-            cache_entry.device = target_device
+            assert isinstance(cache_entry.model, torch.nn.Module)
+            template = cache_entry.model
+            cls = template.__class__
+            with skip_torch_weight_init():
+                if hasattr(cls, "from_config"):
+                    working_model = template.__class__.from_config(template.config)  # diffusers style
+                else:
+                    working_model = template.__class__(config=template.config)  # transformers style (sigh)
+            working_model.to(device=target_device, dtype=self._precision)
+            working_model.load_state_dict(cache_entry.state_dict)
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e
 
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
-        self.logger.debug(
-            f"Moved model '{cache_entry.key}' from {source_device} to"
+        self.logger.info(
+            f"Moved model '{cache_entry.key}' to"
             f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
             f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
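The rebuild path in the hunk above assumes the cached template exposes a .config that its class can be re-instantiated from. Below is a toy, self-contained illustration of the transformers-style branch; ToyModel and its dict config are assumptions, and the real code additionally wraps construction in skip_torch_weight_init() so the throwaway random initialization is skipped.

import torch

class ToyModel(torch.nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.config = config
        self.proj = torch.nn.Linear(config["dim"], config["dim"])

cpu_model = ToyModel({"dim": 16})
state_dict = cpu_model.state_dict()                    # master copy kept in RAM
template = cpu_model.to(device="meta")                 # weight-free template, keeps .config

target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
working = template.__class__(config=template.config)  # fresh instance, as in the transformers branch
working.to(device=target)
working.load_state_dict(state_dict)                    # overwrite the fresh weights with the cached copy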
@@ -380,34 +355,21 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 abs_tol=10 * MB,
             ):
                 self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from {source_device} to"
+                    f"Moving model '{cache_entry.key}' from to"
                     f" {target_device} caused an unexpected change in VRAM usage. The model's"
                     " estimated size may be incorrect. Estimated model size:"
                     f" {(cache_entry.size/GIG):.3f} GB.\n"
                     f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
                 )
+        return working_model
 
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
         ram = "%4.2fG" % (self.cache_size() / GIG)
 
-        in_ram_models = 0
-        in_vram_models = 0
-        locked_in_vram_models = 0
-        for cache_record in self._cached_models.values():
-            if hasattr(cache_record.model, "device"):
-                if cache_record.model.device == self.storage_device:
-                    in_ram_models += 1
-                else:
-                    in_vram_models += 1
-                if cache_record.locked:
-                    locked_in_vram_models += 1
-
-        self.logger.debug(
-            f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
-            f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
-        )
+        in_ram_models = len(self._cached_models)
+        self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")
 
     def make_room(self, size: int) -> None:
         """Make enough room in the cache to accommodate a new model of indicated size."""
@@ -433,29 +395,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
             refs = sys.getrefcount(cache_entry.model)
 
-            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
-            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
-            # https://docs.python.org/3/library/gc.html#gc.get_referrers
-
-            # manualy clear local variable references of just finished function calls
-            # for some reason python don't want to collect it even by gc.collect() immidiately
-            if refs > 2:
-                while True:
-                    cleared = False
-                    for referrer in gc.get_referrers(cache_entry.model):
-                        if type(referrer).__name__ == "frame":
-                            # RuntimeError: cannot clear an executing frame
-                            with suppress(RuntimeError):
-                                referrer.clear()
-                                cleared = True
-                                # break
-
-                    # repeat if referrers changes(due to frame clear), else exit loop
-                    if cleared:
-                        gc.collect()
-                    else:
-                        break
-
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
                 f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
@@ -37,25 +37,22 @@ class ModelLocker(ModelLockerBase):
 
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
-        self._cache_entry.lock()
         try:
             device = self._cache.get_execution_device()
-            self._cache.offload_unlocked_models(self._cache_entry.size)
-            self._cache.move_model_to_device(self._cache_entry, device)
-            self._cache_entry.loaded = True
-            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}")
+            model_on_device = self._cache.model_to_device(self._cache_entry, device)
+            self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
             self._cache.print_cuda_stats()
         except torch.cuda.OutOfMemoryError:
             self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
-            self._cache_entry.unlock()
             raise
         except Exception:
-            self._cache_entry.unlock()
             raise
 
-        return self.model
+        return model_on_device
 
+    # It is no longer necessary to move the model out of VRAM
+    # because it will be removed when it goes out of scope
+    # in the caller's context
     def unlock(self) -> None:
         """Call upon exit from context."""
-        self._cache_entry.unlock()
         self._cache.print_cuda_stats()
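A hedged sketch of how a caller might use the revised locker: lock() now returns a per-use copy on the execution device, and VRAM is reclaimed simply by letting that copy go out of scope. run_inference() and the callable-model assumption are illustrative, not InvokeAI API.

import torch

def run_inference(locker, batch):
    model_on_device = locker.lock()   # builds and returns the copy on the execution device
    try:
        with torch.no_grad():
            return model_on_device(batch)
    finally:
        # unlock() only logs stats now; the VRAM copy is freed once
        # model_on_device goes out of scope in this caller.
        locker.unlock()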
@@ -129,9 +129,7 @@ class ModelPatcher:
                 dtype = module.weight.dtype
 
                 if module_key not in original_weights:
-                    if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                        original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                    else:
+                    if model_state_dict is None:  # no CPU copy of the state dict was provided
                         original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
@@ -158,6 +156,9 @@ class ModelPatcher:
             yield  # wait for context manager exit
 
         finally:
+            # LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
+            # Therefore it should not be necessary to copy the original model weights back.
+            # This needs to be fixed before resurrecting the VRAM cache.
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
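For reference, a minimal, library-free sketch (assumed names, not the ModelPatcher API) of the save/patch/restore pattern that the finally block above performs for LoRA-patched weights:

import torch

model = torch.nn.Linear(4, 4)
original_weights = {"weight": model.weight.detach().to(device="cpu", copy=True)}

with torch.no_grad():
    model.weight += 0.1 * torch.randn_like(model.weight)   # stand-in for applying a LoRA patch

# restore on exit, mirroring the loop in the finally block:
with torch.no_grad():
    for name, saved in original_weights.items():
        getattr(model, name).copy_(saved)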