diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 2eb675c3e3..1af02637d3 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -21,6 +21,7 @@ import gc
 import hashlib
 import warnings
 from collections.abc import Generator
+from collections import Counter
 from enum import Enum
 from pathlib import Path
 from typing import Sequence, Union
@@ -36,6 +37,7 @@ from pydantic import BaseModel
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from transformers import logging as transformers_logging
 
+import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
 from . import load_pipeline_from_original_stable_diffusion_ckpt
@@ -77,6 +79,7 @@ class ModelCache(object):
         storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,
         sequential_offload: bool=False,
+        lazy_offloading: bool=True,
         sha_chunksize: int = 16777216,
     ):
         '''
@@ -84,17 +87,21 @@ class ModelCache(object):
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.models: dict = dict()
         self.stack: Sequence = list()
+        self.lazy_offloading = lazy_offloading
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
         self.max_models: int=max_models
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
+        self.loaded_models: set = set()          # set of model keys loaded in GPU
+        self.locked_models: Counter = Counter()  # number of locks held on each model key in GPU
 
     @contextlib.contextmanager
     def get_model(
@@ -149,21 +156,39 @@ class ModelCache(object):
         if submodel:
             model = getattr(model, submodel.name)
-            debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
-        else:
-            debugging_name = repo_id_or_path
-
-        try:
-            if gpu_load and hasattr(model,'to'):
-                print(f'  | Loading {debugging_name} into GPU')
+
+        if gpu_load and hasattr(model,'to'):
+            try:
+                self.loaded_models.add(key)
+                self.locked_models[key] += 1
+                if self.lazy_offloading:
+                    self._offload_unlocked_models()
+                logger.debug(f'Loading {key} into {self.execution_device}')
                 model.to(self.execution_device)  # move into GPU
                 self._print_cuda_stats()
-            yield model
-        finally:
-            if gpu_load and hasattr(model,'to'):
-                print(f'  | Unloading {debugging_name} from GPU')
-                model.to(self.storage_device)
+                yield model
+            finally:
+                self.locked_models[key] -= 1
+                if not self.lazy_offloading:
+                    self._offload_unlocked_models()
                 self._print_cuda_stats()
+        else:
+            # in the event that the caller wants the model in RAM, we
+            # move it into CPU if it is in GPU and not locked
+            if hasattr(model,'to') and (key in self.loaded_models
+                                        and self.locked_models[key] == 0):
+                model.to(self.storage_device)
+            yield model
+
+    def _offload_unlocked_models(self):
+        to_offload = set()
+        for key in self.loaded_models:
+            if key not in self.locked_models or self.locked_models[key] == 0:
+                logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
+                to_offload.add(key)
+        for key in to_offload:
+            self.models[key].to(self.storage_device)
+            self.loaded_models.remove(key)
 
     def model_hash(self,
                    repo_id_or_path: Union[str,Path],
@@ -193,16 +218,16 @@ class ModelCache(object):
         return self.execution_device.type == 'cuda'
 
     def _print_cuda_stats(self):
-        print(
-            "  | Current VRAM usage:",
-            "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
-        )
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+        loaded_models = len(self.loaded_models)
+        locked_models = len([x for x in self.locked_models if self.locked_models[x] > 0])
+        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
 
     def _make_cache_room(self):
         models_in_ram = len(self.models)
         while models_in_ram >= self.max_models:
             if least_recently_used_key := self.stack.pop(0):
-                print(f'  | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
+                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                 del self.models[least_recently_used_key]
             models_in_ram = len(self.models)
         gc.collect()
@@ -291,7 +316,7 @@ class ModelCache(object):
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f"  | Scanning Model: {model_name}")
+        logger.debug(f"Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -299,7 +324,7 @@ class ModelCache(object):
             else:
                 raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
         else:
-            print("  | Model scanned ok")
+            logger.debug("Model scanned ok")
 
     def _load_ckpt_from_storage(self,
                                 ckpt_path: Union[str,Path],
@@ -330,7 +355,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash
 
-        print(f'  | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         with open(path, "rb") as f:
             while chunk := f.read(self.sha_chunksize):
                 sha.update(chunk)
@@ -350,7 +375,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash
 
-        print(f'  | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         for file in list(path.rglob("*.ckpt")) \
             + list(path.rglob("*.safetensors")) \
             + list(path.rglob("*.pth")):
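
For reviewers, a minimal usage sketch of the lazy-offloading behaviour introduced above. The full get_model() signature is not visible in this hunk, so the constructor arguments and the gpu_load keyword below are assumptions inferred from the parameters and attributes referenced in the diff; the model identifiers are placeholders.

import torch

from invokeai.backend.model_management.model_cache import ModelCache

# Assumed constructor arguments, based on the attributes assigned in __init__.
cache = ModelCache(
    max_models=4,
    execution_device=torch.device('cuda'),
    storage_device=torch.device('cpu'),
    precision=torch.float16,
    lazy_offloading=True,   # new flag added by this diff
)

# While the context is open, the model's key is in cache.loaded_models and its
# count in cache.locked_models is > 0, so it will not be offloaded.
with cache.get_model('runwayml/stable-diffusion-v1-5', gpu_load=True) as model:
    pass  # run inference here

# With lazy_offloading=True the model stays on the execution device after the
# block exits; its lock count drops to zero but it is not moved to storage.
# Requesting another model calls _offload_unlocked_models() first, which moves
# every loaded-but-unlocked model back to the storage device before loading.
with cache.get_model('stabilityai/stable-diffusion-2-1', gpu_load=True) as model:
    pass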