implement lazy GPU offloading and ref counting

Lincoln Stein 2023-05-04 23:15:32 -04:00
parent a273bdbdc1
commit 68bc0112fa


@@ -21,6 +21,7 @@ import gc
 import hashlib
 import warnings
 from collections.abc import Generator
+from collections import Counter
 from enum import Enum
 from pathlib import Path
 from typing import Sequence, Union
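
The new Counter import backs the ref counting introduced below: a Counter reads missing keys as 0, so lock counts need no initialization before the first increment. A quick illustration of that property (the model names here are made up):

    from collections import Counter

    locks = Counter()
    locks['unet'] += 1    # no KeyError on first use; a missing key reads as 0
    locks['unet'] += 1    # a second concurrent caller
    locks['unet'] -= 1
    print(locks['unet'])  # 1 -> still locked by one caller
    print(locks['vae'])   # 0 -> unlocked, safe to offload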
@@ -36,6 +37,7 @@ from pydantic import BaseModel
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from transformers import logging as transformers_logging
+import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
 from . import load_pipeline_from_original_stable_diffusion_ckpt
@@ -77,6 +79,7 @@ class ModelCache(object):
         storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,
         sequential_offload: bool=False,
+        lazy_offloading: bool=True,
         sha_chunksize: int = 16777216,
     ):
         '''
@@ -84,17 +87,21 @@ class ModelCache(object):
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.models: dict = dict()
         self.stack: Sequence = list()
+        self.lazy_offloading = lazy_offloading
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
         self.max_models: int=max_models
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
+        self.loaded_models: set = set()          # set of model keys loaded in GPU
+        self.locked_models: Counter = Counter()  # model keys mapped to the number of outstanding GPU locks
 
     @contextlib.contextmanager
     def get_model(
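
With the new flag in place, a cache might be constructed like this (a usage sketch, not part of the commit; only the parameters visible in this hunk are shown, and max_models keeps whatever default the class defines):

    import torch

    cache = ModelCache(
        execution_device=torch.device('cuda'),
        storage_device=torch.device('cpu'),
        precision=torch.float16,
        lazy_offloading=True,  # keep models in VRAM until another model needs the space
    )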
@@ -149,21 +156,39 @@ class ModelCache(object):
         if submodel:
             model = getattr(model, submodel.name)
-            debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
-        else:
-            debugging_name = repo_id_or_path
-        try:
-            if gpu_load and hasattr(model,'to'):
-                print(f'  | Loading {debugging_name} into GPU')
+
+        if gpu_load and hasattr(model,'to'):
+            try:
+                self.loaded_models.add(key)
+                self.locked_models[key] += 1
+                if self.lazy_offloading:
+                    self._offload_unlocked_models()
+                logger.debug(f'Loading {key} into {self.execution_device}')
                 model.to(self.execution_device)  # move into GPU
                 self._print_cuda_stats()
-            yield model
-        finally:
-            if gpu_load and hasattr(model,'to'):
-                print(f'  | Unloading {debugging_name} from GPU')
-                model.to(self.storage_device)
+                yield model
+            finally:
+                self.locked_models[key] -= 1
+                if not self.lazy_offloading:
+                    self._offload_unlocked_models()
+                self._print_cuda_stats()
+        else:
+            # in the event that the caller wants the model in RAM, we
+            # move it into CPU if it is in GPU and not locked
+            if hasattr(model,'to') and (key in self.loaded_models
+                                        and self.locked_models[key] == 0):
+                model.to(self.storage_device)
+            yield model
+
+    def _offload_unlocked_models(self):
+        to_offload = set()
+        for key in self.loaded_models:
+            if key not in self.locked_models or self.locked_models[key] == 0:
+                logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
+                to_offload.add(key)
+        for key in to_offload:
+            self.models[key].to(self.storage_device)
+            self.loaded_models.remove(key)
 
     def model_hash(self,
                    repo_id_or_path: Union[str,Path],
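
The rewritten get_model is a ref-counted GPU lock: every nested get_model call on the same key bumps the Counter, and _offload_unlocked_models may evict a model only once its count is back to zero. With lazy_offloading=True the offload pass runs just before a new load, so models linger in VRAM until the space is actually needed; with it off, the pass runs in the finally block as soon as the lock is released. The same idea as a minimal standalone sketch (the GpuLock class and its names are illustrative, not part of the commit):

    import contextlib
    from collections import Counter

    class GpuLock:
        def __init__(self):
            self.loaded: set = set()         # keys currently resident on the GPU
            self.locks: Counter = Counter()  # key -> number of active users

        @contextlib.contextmanager
        def hold(self, key):
            self.loaded.add(key)
            self.locks[key] += 1             # re-entrant: nested holds just bump the count
            try:
                yield
            finally:
                self.locks[key] -= 1         # offloadable again once this reaches zero

        def offloadable(self) -> set:
            return {k for k in self.loaded if self.locks[k] == 0}

    lock = GpuLock()
    with lock.hold('unet'):
        with lock.hold('unet'):              # nested use of the same model is safe
            assert lock.offloadable() == set()
    assert lock.offloadable() == {'unet'}    # unlocked now, may be moved to CPU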
@@ -193,16 +218,16 @@ class ModelCache(object):
         return self.execution_device.type == 'cuda'
 
     def _print_cuda_stats(self):
-        print(
-            "  | Current VRAM usage:",
-            "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
-        )
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+        loaded_models = len(self.loaded_models)
+        locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
+        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
 
     def _make_cache_room(self):
         models_in_ram = len(self.models)
         while models_in_ram >= self.max_models:
             if least_recently_used_key := self.stack.pop(0):
-                print(f'  | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
+                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                 del self.models[least_recently_used_key]
             models_in_ram = len(self.models)
         gc.collect()
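
_make_cache_room evicts from a plain list used as an LRU stack: index 0 holds the least recently used key, and the walrus expression both pops it and tests it for truthiness. The loop in isolation, with toy values standing in for real models:

    models = {'a': 'weights-a', 'b': 'weights-b', 'c': 'weights-c'}
    stack = ['a', 'b', 'c']  # 'a' was used least recently
    max_models = 2

    while len(models) >= max_models:
        if least_recently_used_key := stack.pop(0):
            del models[least_recently_used_key]  # drops 'a', then 'b'
    print(models)  # {'c': 'weights-c'}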
@@ -291,7 +316,7 @@ class ModelCache(object):
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f"  | Scanning Model: {model_name}")
+        logger.debug(f"Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -299,7 +324,7 @@ class ModelCache(object):
             else:
                 raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
         else:
-            print("  | Model scanned ok")
+            logger.debug("Model scanned ok")
 
     def _load_ckpt_from_storage(self,
                                 ckpt_path: Union[str,Path],
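
The scan_file_path call and its infected_files count match picklescan's scanner API; the import below is an assumption, since this diff does not show it, and the checkpoint path is hypothetical:

    from picklescan.scanner import scan_file_path  # assumed; not shown in this diff

    result = scan_file_path('/models/some-legacy-model.ckpt')  # hypothetical path
    if result.infected_files != 0:
        raise RuntimeError('refusing to load a potentially malicious checkpoint')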
@@ -330,7 +355,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash
 
-        print(f'  | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         with open(path, "rb") as f:
             while chunk := f.read(self.sha_chunksize):
                 sha.update(chunk)
@@ -350,7 +375,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash
 
-        print(f'  | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         for file in list(path.rglob("*.ckpt")) \
             + list(path.rglob("*.safetensors")) \
             + list(path.rglob("*.pth")):
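
Both hash helpers read files in sha_chunksize pieces, so multi-gigabyte checkpoints never need to fit in memory at once. The core of the technique as a standalone function, assuming the same 16 MiB default seen in __init__:

    import hashlib
    from pathlib import Path

    def file_sha256(path: Path, chunksize: int = 16777216) -> str:
        sha = hashlib.sha256()
        with open(path, 'rb') as f:
            while chunk := f.read(chunksize):  # b'' at EOF ends the loop
                sha.update(chunk)
        return sha.hexdigest()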