implement lazy GPU offloading and ref counting

Lincoln Stein 2023-05-04 23:15:32 -04:00
parent a273bdbdc1
commit 68bc0112fa

@@ -21,6 +21,7 @@ import gc
 import hashlib
 import warnings
 from collections.abc import Generator
+from collections import Counter
 from enum import Enum
 from pathlib import Path
 from typing import Sequence, Union
@@ -36,6 +37,7 @@ from pydantic import BaseModel
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from transformers import logging as transformers_logging
+import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
 from . import load_pipeline_from_original_stable_diffusion_ckpt
@@ -77,6 +79,7 @@ class ModelCache(object):
         storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,
         sequential_offload: bool=False,
+        lazy_offloading: bool=True,
         sha_chunksize: int = 16777216,
     ):
         '''
@@ -84,17 +87,21 @@ class ModelCache(object):
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.models: dict = dict()
         self.stack: Sequence = list()
+        self.lazy_offloading = lazy_offloading
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
         self.max_models: int=max_models
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
+        self.loaded_models: set = set()        # set of model keys loaded in GPU
+        self.locked_models: Counter = Counter()  # set of model keys locked in GPU

     @contextlib.contextmanager
     def get_model(
@@ -149,21 +156,39 @@ class ModelCache(object):
         if submodel:
             model = getattr(model, submodel.name)
-            debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
-        else:
-            debugging_name = repo_id_or_path
-        try:
-            if gpu_load and hasattr(model,'to'):
-                print(f' | Loading {debugging_name} into GPU')
-                model.to(self.execution_device)  # move into GPU
-                self._print_cuda_stats()
-            yield model
-        finally:
-            if gpu_load and hasattr(model,'to'):
-                print(f' | Unloading {debugging_name} from GPU')
-                model.to(self.storage_device)
-                self._print_cuda_stats()
+
+        if gpu_load and hasattr(model,'to'):
+            try:
+                self.loaded_models.add(key)
+                self.locked_models[key] += 1
+                if self.lazy_offloading:
+                    self._offload_unlocked_models()
+                logger.debug(f'Loading {key} into {self.execution_device}')
+                model.to(self.execution_device)  # move into GPU
+                self._print_cuda_stats()
+                yield model
+            finally:
+                self.locked_models[key] -= 1
+                if not self.lazy_offloading:
+                    self._offload_unlocked_models()
+                self._print_cuda_stats()
+        else:
+            # in the event that the caller wants the model in RAM, we
+            # move it into CPU if it is in GPU and not locked
+            if hasattr(model,'to') and (key in self.loaded_models
+                                        and self.locked_models[key] == 0):
+                model.to(self.storage_device)
+            yield model
+
+    def _offload_unlocked_models(self):
+        to_offload = set()
+        for key in self.loaded_models:
+            if key not in self.locked_models or self.locked_models[key] == 0:
+                logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
+                to_offload.add(key)
+        for key in to_offload:
+            self.models[key].to(self.storage_device)
+            self.loaded_models.remove(key)

     def model_hash(self,
                    repo_id_or_path: Union[str,Path],
@@ -193,16 +218,16 @@ class ModelCache(object):
         return self.execution_device.type == 'cuda'

     def _print_cuda_stats(self):
-        print(
-            " | Current VRAM usage:",
-            "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
-        )
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+        loaded_models = len(self.loaded_models)
+        locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
+        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")

     def _make_cache_room(self):
         models_in_ram = len(self.models)
         while models_in_ram >= self.max_models:
             if least_recently_used_key := self.stack.pop(0):
-                print(f' | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
+                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                 del self.models[least_recently_used_key]
             models_in_ram = len(self.models)
         gc.collect()
@@ -291,7 +316,7 @@ class ModelCache(object):
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f" | Scanning Model: {model_name}")
+        logger.debug(f"Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -299,7 +324,7 @@ class ModelCache(object):
             else:
                 raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
         else:
-            print(" | Model scanned ok")
+            logger.debug("Model scanned ok")

     def _load_ckpt_from_storage(self,
                                 ckpt_path: Union[str,Path],
@@ -330,7 +355,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash

-        print(f' | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         with open(path, "rb") as f:
             while chunk := f.read(self.sha_chunksize):
                 sha.update(chunk)
@@ -350,7 +375,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash

-        print(f' | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         for file in list(path.rglob("*.ckpt")) \
                 + list(path.rglob("*.safetensors")) \
                 + list(path.rglob("*.pth")):
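
For context, the caller-side behavior this commit aims for can be shown with a small standalone sketch. This is an illustration only, not part of the commit and not the real ModelCache API: TinyCache, FakeModel, and the simplified get_model()/_offload_unlocked() below are hypothetical stand-ins that assume a model object exposing .to(device). A model's reference count is incremented while it is checked out and decremented on release; with lazy_offloading=True it stays on the execution device after release and is only evicted when another get_model() call needs the space.

from collections import Counter
from contextlib import contextmanager

class FakeModel:
    # hypothetical stand-in for a diffusers/transformers model with a .to() method
    def __init__(self, name):
        self.name, self.device = name, 'cpu'
    def to(self, device):
        self.device = device
        return self

class TinyCache:
    # toy illustration of ref-counted lazy offloading; not the real ModelCache
    def __init__(self, lazy_offloading=True):
        self.lazy_offloading = lazy_offloading
        self.models = {}                  # key -> model object
        self.loaded_models = set()        # keys currently on the execution device
        self.locked_models = Counter()    # key -> number of active users (ref count)

    @contextmanager
    def get_model(self, key, execution_device='cuda', storage_device='cpu'):
        model = self.models[key]
        try:
            self.loaded_models.add(key)
            self.locked_models[key] += 1          # acquire a reference
            if self.lazy_offloading:
                self._offload_unlocked(storage_device)
            model.to(execution_device)            # "move into GPU"
            yield model
        finally:
            self.locked_models[key] -= 1          # release the reference
            if not self.lazy_offloading:
                self._offload_unlocked(storage_device)

    def _offload_unlocked(self, storage_device):
        for key in list(self.loaded_models):
            if self.locked_models[key] == 0:      # nobody holds it -> safe to evict
                self.models[key].to(storage_device)
                self.loaded_models.remove(key)

cache = TinyCache(lazy_offloading=True)
cache.models = {'A': FakeModel('A'), 'B': FakeModel('B')}

with cache.get_model('A') as a:
    assert a.device == 'cuda'
assert cache.models['A'].device == 'cuda'     # lazy: 'A' stays loaded after release

with cache.get_model('B'):                    # loading 'B' evicts the now-unlocked 'A'
    assert cache.models['A'].device == 'cpu'
    assert cache.models['B'].device == 'cuda'

The Counter is what makes overlapping check-outs of the same model safe: only a key whose count has dropped back to zero is eligible for offload, mirroring the locked_models bookkeeping added in this commit.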