From 2219e3643a5392f778afeff8194d758d3fb96ae5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 24 Jun 2024 10:55:15 -0400 Subject: [PATCH] copy model from a meta device template - temporarily disable vram cache --- .../load/model_cache/model_cache_base.py | 21 +--- .../load/model_cache/model_cache_default.py | 119 +++++------------- .../load/model_cache/model_locker.py | 15 +-- invokeai/backend/model_patcher.py | 7 +- 4 files changed, 40 insertions(+), 122 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index b3e4e3ac12..9f9dd26362 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -52,11 +52,10 @@ class CacheRecord(Generic[T]): Elements of the cache: key: Unique key for each model, same as used in the models database. - model: Model in memory. + model: Read-only copy of the model *without weights* residing in the "meta device" state_dict: A read-only copy of the model's state dict in RAM. It will be used as a template for creating a copy in the VRAM. size: Size of the model - loaded: True if the model's state dict is currently in VRAM Before a model is executed, the state_dict template is copied into VRAM, and then injected into the model. When the model is finished, the VRAM @@ -72,25 +71,7 @@ class CacheRecord(Generic[T]): key: str size: int model: T - device: torch.device state_dict: Optional[Dict[str, torch.Tensor]] - size: int - loaded: bool = False - _locks: int = 0 - - def lock(self) -> None: - """Lock this record.""" - self._locks += 1 - - def unlock(self) -> None: - """Unlock this record.""" - self._locks -= 1 - assert self._locks >= 0 - - @property - def locked(self) -> bool: - """Return true if record is locked.""" - return self._locks > 0 @dataclass diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index a00c8fcb87..6357ada241 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -36,6 +36,7 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_da from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger +from ..optimizations import skip_torch_weight_init from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase from .model_locker import ModelLocker @@ -221,8 +222,12 @@ class ModelCache(ModelCacheBase[AnyModel]): size = calc_model_size_by_data(model) self.make_room(size) - state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None - cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) + if isinstance(model, torch.nn.Module): + state_dict = model.state_dict() # keep a master copy of the state dict + model = model.to(device="meta") # and keep a template in the meta device + else: + state_dict = None + cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -284,48 +289,20 @@ class ModelCache(ModelCacheBase[AnyModel]): else: return model_key - def offload_unlocked_models(self, size_required: int) -> None: - """Move any unused models from VRAM.""" - device = 
self.get_execution_device() - reserved = self._max_vram_cache_size * GIG - vram_in_use = torch.cuda.memory_allocated(device) + size_required - self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") - for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): - if vram_in_use <= reserved: - break - if not cache_entry.loaded: - continue - if cache_entry.device is not device: - continue - if not cache_entry.locked: - self.move_model_to_device(cache_entry, self.storage_device) - cache_entry.loaded = False - vram_in_use = torch.cuda.memory_allocated() + size_required - self.logger.debug( - f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" - ) - - TorchDevice.empty_cache() - - def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: - """Move model into the indicated device. + def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel: + """Move a copy of the model into the indicated device and return it. :param cache_entry: The CacheRecord for the model :param target_device: The torch.device to move the model into May raise a torch.cuda.OutOfMemoryError """ - self.logger.debug(f"Called to move {cache_entry.key} to {target_device}") - source_device = cache_entry.device + self.logger.info(f"Called to move {cache_entry.key} to {target_device}") - # Note: We compare device types only so that 'cuda' == 'cuda:0'. - # This would need to be revised to support multi-GPU. - if torch.device(source_device).type == torch.device(target_device).type: - return - - # Some models don't have a `to` method, in which case they run in RAM/CPU. - if not hasattr(cache_entry.model, "to"): - return + # Some models don't have a state dictionary, in which case the + # stored model will still reside in CPU + if cache_entry.state_dict is None: + return cache_entry.model # This roundabout method for moving the model around is done to avoid # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. @@ -338,27 +315,25 @@ class ModelCache(ModelCacheBase[AnyModel]): # in RAM into the model. So this operation is very fast. 
start_model_to_time = time.time() snapshot_before = self._capture_memory_snapshot() - try: - if cache_entry.state_dict is not None: - assert hasattr(cache_entry.model, "load_state_dict") - if target_device == self.storage_device: - cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) + assert isinstance(cache_entry.model, torch.nn.Module) + template = cache_entry.model + cls = template.__class__ + with skip_torch_weight_init(): + if hasattr(cls, "from_config"): + working_model = template.__class__.from_config(template.config) # diffusers style else: - new_dict: Dict[str, torch.Tensor] = {} - for k, v in cache_entry.state_dict.items(): - new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True) - cache_entry.model.load_state_dict(new_dict, assign=True) - cache_entry.model.to(target_device, non_blocking=True) - cache_entry.device = target_device + working_model = template.__class__(config=template.config) # transformers style (sigh) + working_model.to(device=target_device, dtype=self._precision) + working_model.load_state_dict(cache_entry.state_dict) except Exception as e: # blow away cache entry self._delete_cache_entry(cache_entry) raise e snapshot_after = self._capture_memory_snapshot() end_model_to_time = time.time() - self.logger.debug( - f"Moved model '{cache_entry.key}' from {source_device} to" + self.logger.info( + f"Moved model '{cache_entry.key}' to" f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." f"Estimated model size: {(cache_entry.size/GIG):.3f} GB." f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" @@ -380,34 +355,21 @@ class ModelCache(ModelCacheBase[AnyModel]): abs_tol=10 * MB, ): self.logger.debug( - f"Moving model '{cache_entry.key}' from {source_device} to" + f"Moving model '{cache_entry.key}' from to" f" {target_device} caused an unexpected change in VRAM usage. The model's" " estimated size may be incorrect. Estimated model size:" f" {(cache_entry.size/GIG):.3f} GB.\n" f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" ) + return working_model def print_cuda_stats(self) -> None: """Log CUDA diagnostics.""" vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) ram = "%4.2fG" % (self.cache_size() / GIG) - in_ram_models = 0 - in_vram_models = 0 - locked_in_vram_models = 0 - for cache_record in self._cached_models.values(): - if hasattr(cache_record.model, "device"): - if cache_record.model.device == self.storage_device: - in_ram_models += 1 - else: - in_vram_models += 1 - if cache_record.locked: - locked_in_vram_models += 1 - - self.logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" - f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" - ) + in_ram_models = len(self._cached_models) + self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}") def make_room(self, size: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size.""" @@ -433,29 +395,6 @@ class ModelCache(ModelCacheBase[AnyModel]): refs = sys.getrefcount(cache_entry.model) - # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. 
We are directly - # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way: - # https://docs.python.org/3/library/gc.html#gc.get_referrers - - # manualy clear local variable references of just finished function calls - # for some reason python don't want to collect it even by gc.collect() immidiately - if refs > 2: - while True: - cleared = False - for referrer in gc.get_referrers(cache_entry.model): - if type(referrer).__name__ == "frame": - # RuntimeError: cannot clear an executing frame - with suppress(RuntimeError): - referrer.clear() - cleared = True - # break - - # repeat if referrers changes(due to frame clear), else exit loop - if cleared: - gc.collect() - else: - break - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None self.logger.debug( f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 9f9c05bce5..815fd41f04 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -37,25 +37,22 @@ class ModelLocker(ModelLockerBase): def lock(self) -> AnyModel: """Move the model into the execution device (GPU) and lock it.""" - self._cache_entry.lock() try: device = self._cache.get_execution_device() - self._cache.offload_unlocked_models(self._cache_entry.size) - self._cache.move_model_to_device(self._cache_entry, device) - self._cache_entry.loaded = True - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {device}") + model_on_device = self._cache.model_to_device(self._cache_entry, device) + self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}") self._cache.print_cuda_stats() except torch.cuda.OutOfMemoryError: self._cache.logger.warning("Insufficient GPU memory to load model. Aborting") - self._cache_entry.unlock() raise except Exception: - self._cache_entry.unlock() raise - return self.model + return model_on_device + # It is no longer necessary to move the model out of VRAM + # because it will be removed when it goes out of scope + # in the caller's context def unlock(self) -> None: """Call upon exit from context.""" - self._cache_entry.unlock() self._cache.print_cuda_stats() diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index fdc79539ae..0f57c0efdc 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -129,9 +129,7 @@ class ModelPatcher: dtype = module.weight.dtype if module_key not in original_weights: - if model_state_dict is not None: # we were provided with the CPU copy of the state dict - original_weights[module_key] = model_state_dict[module_key + ".weight"] - else: + if model_state_dict is None: # no CPU copy of the state dict was provided original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 @@ -158,6 +156,9 @@ class ModelPatcher: yield # wait for context manager exit finally: + # LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed. + # Therefore it should not be necessary to copy the original model weights back. + # This needs to be fixed before resurrecting the VRAM cache. 
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
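
Note: the following is a minimal, self-contained sketch (not part of the patch; the helper names are illustrative) of the caching pattern this commit adopts: keep one CPU master copy of the weights plus a weightless meta-device template, and materialize a fresh working copy on the execution device whenever the model is needed. The sketch uses torch.nn.Module.to_empty() to allocate the copy; the patch itself re-instantiates the model from its config under skip_torch_weight_init(), which achieves the same "allocate without initializing" effect for diffusers/transformers models.

import copy

import torch


def split_into_template(model: torch.nn.Module):
    """Return (template, state_dict): a weightless meta-device template plus a CPU master copy of the weights."""
    state_dict = model.state_dict()     # detached CPU tensors; this is the RAM master copy
    template = model.to(device="meta")  # parameter data is dropped, only the module structure remains
    return template, state_dict


def materialize(template: torch.nn.Module, state_dict, device: torch.device) -> torch.nn.Module:
    """Build a fresh working copy of `template` on `device` and fill it from the RAM state dict."""
    # deepcopy of a meta-device module is cheap because meta tensors carry no data;
    # to_empty() then gives every parameter/buffer real but uninitialized storage on the
    # target device, so no time is wasted initializing weights that are about to be overwritten.
    working = copy.deepcopy(template).to_empty(device=device)
    working.load_state_dict(state_dict)  # copy the RAM weights into the device tensors
    return working


if __name__ == "__main__":
    net = torch.nn.Linear(8, 4)
    template, weights = split_into_template(net)
    worker = materialize(template, weights, torch.device("cpu"))  # or torch.device("cuda")
    print(worker(torch.randn(2, 8)).shape)  # torch.Size([2, 4])

Because the working copy is an ordinary object returned to the caller, it is simply garbage-collected (and its VRAM freed) when it goes out of scope, which is why the lock()/unlock() bookkeeping and offload_unlocked_models() can be dropped in the hunks above.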