mirror of https://github.com/invoke-ai/InvokeAI
implement lazy GPU offloading and ref counting
This commit is contained in:
parent a273bdbdc1
commit 68bc0112fa
@@ -21,6 +21,7 @@ import gc
 import hashlib
 import warnings
 from collections.abc import Generator
+from collections import Counter
 from enum import Enum
 from pathlib import Path
 from typing import Sequence, Union
@@ -36,6 +37,7 @@ from pydantic import BaseModel
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from transformers import logging as transformers_logging
 
+import invokeai.backend.util.logging as logger
 from ..globals import global_cache_dir
 from ..stable_diffusion import StableDiffusionGeneratorPipeline
 from . import load_pipeline_from_original_stable_diffusion_ckpt
@@ -77,6 +79,7 @@ class ModelCache(object):
         storage_device: torch.device=torch.device('cpu'),
         precision: torch.dtype=torch.float16,
         sequential_offload: bool=False,
+        lazy_offloading: bool=True,
         sha_chunksize: int = 16777216,
     ):
         '''
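
For orientation, a minimal sketch of constructing the cache with the new flag. The argument values below are illustrative only and are not taken from the commit; max_models is referenced elsewhere in the diff, and ModelCache is assumed to be importable from the model-management module.

import torch

# Hypothetical construction; argument values are examples, not project defaults.
cache = ModelCache(
    max_models=4,
    execution_device=torch.device('cuda'),
    storage_device=torch.device('cpu'),
    precision=torch.float16,
    sequential_offload=False,
    lazy_offloading=True,   # keep models in VRAM until another model needs the space
)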
@@ -84,17 +87,21 @@ class ModelCache(object):
         :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
+        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
         '''
         self.models: dict = dict()
         self.stack: Sequence = list()
+        self.lazy_offloading = lazy_offloading
         self.sequential_offload: bool=sequential_offload
         self.precision: torch.dtype=precision
         self.max_models: int=max_models
         self.execution_device: torch.device=execution_device
         self.storage_device: torch.device=storage_device
         self.sha_chunksize=sha_chunksize
+        self.loaded_models: set = set()          # set of model keys loaded in GPU
+        self.locked_models: Counter = Counter()  # set of model keys locked in GPU
 
     @contextlib.contextmanager
     def get_model(
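
Taken together, the two new attributes implement the ref counting named in the commit title: loaded_models records which keys currently sit on the execution device, while locked_models is a collections.Counter holding a per-key reference count, and a model only becomes a candidate for offloading once its count returns to zero. A stripped-down sketch of that bookkeeping follows, using an illustrative MiniCache class that is not part of the commit:

from collections import Counter

class MiniCache:
    '''Illustrative reduction of the loaded/locked bookkeeping above.'''
    def __init__(self):
        self.loaded_models: set = set()          # keys currently on the GPU
        self.locked_models: Counter = Counter()  # per-key reference counts

    def lock(self, key: str):
        self.loaded_models.add(key)
        self.locked_models[key] += 1

    def unlock(self, key: str):
        self.locked_models[key] -= 1

    def offloadable(self) -> set:
        # keys resident on the GPU whose reference count has fallen to zero
        return {k for k in self.loaded_models if self.locked_models[k] == 0}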
@@ -149,21 +156,39 @@ class ModelCache(object):
 
         if submodel:
             model = getattr(model, submodel.name)
             debugging_name = f'{submodel.name} submodel of {repo_id_or_path}'
         else:
             debugging_name = repo_id_or_path
 
-        try:
-            if gpu_load and hasattr(model,'to'):
-                print(f' | Loading {debugging_name} into GPU')
+        if gpu_load and hasattr(model,'to'):
+            try:
+                self.loaded_models.add(key)
+                self.locked_models[key] += 1
+                if self.lazy_offloading:
+                    self._offload_unlocked_models()
+                logger.debug(f'Loading {key} into {self.execution_device}')
                 model.to(self.execution_device) # move into GPU
                 self._print_cuda_stats()
-            yield model
-        finally:
-            if gpu_load and hasattr(model,'to'):
-                print(f' | Unloading {debugging_name} from GPU')
-                model.to(self.storage_device)
+                yield model
+            finally:
+                self.locked_models[key] -= 1
+                if not self.lazy_offloading:
+                    self._offload_unlocked_models()
+                self._print_cuda_stats()
+        else:
+            # in the event that the caller wants the model in RAM, we
+            # move it into CPU if it is in GPU and not locked
+            if hasattr(model,'to') and (key in self.loaded_models
+                                        and self.locked_models[key] == 0):
+                model.to(self.storage_device)
+            yield model
 
+    def _offload_unlocked_models(self):
+        to_offload = set()
+        for key in self.loaded_models:
+            if key not in self.locked_models or self.locked_models[key] == 0:
+                logger.debug(f'Offloading {key} from {self.execution_device} into {self.storage_device}')
+                to_offload.add(key)
+        for key in to_offload:
+            self.models[key].to(self.storage_device)
+            self.loaded_models.remove(key)
+
     def model_hash(self,
                    repo_id_or_path: Union[str,Path],
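
Because the lock is a counter rather than a flag, nested or overlapping get_model() calls keep every active model pinned on the execution device; _offload_unlocked_models() only moves keys whose count is back at zero. With lazy_offloading=True that sweep is deferred until the next model actually needs to be loaded, while lazy_offloading=False restores eager offloading as soon as the context exits. A hedged usage sketch, assuming cache is a ModelCache instance; the model identifiers and call style are examples, not taken from the commit:

# Both models stay locked (ref count > 0) while the inner block runs.
with cache.get_model('runwayml/stable-diffusion-v1-5') as pipeline:
    with cache.get_model('stabilityai/sd-vae-ft-mse') as vae:
        ...  # use both models
    # vae's count drops back to zero here; under lazy offloading it remains
    # in VRAM until some other model actually needs to be loaded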
@@ -193,16 +218,16 @@ class ModelCache(object):
         return self.execution_device.type == 'cuda'
 
     def _print_cuda_stats(self):
-        print(
-            " | Current VRAM usage:",
-            "%4.2fG" % (torch.cuda.memory_allocated() / 1e9),
-        )
+        vram = "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)
+        loaded_models = len(self.loaded_models)
+        locked_models = len([x for x in self.locked_models if self.locked_models[x]>0])
+        logger.debug(f"Current VRAM usage: {vram}; locked_models/loaded_models = {locked_models}/{loaded_models}")
 
     def _make_cache_room(self):
         models_in_ram = len(self.models)
         while models_in_ram >= self.max_models:
             if least_recently_used_key := self.stack.pop(0):
-                print(f' | Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
+                logger.debug(f'Maximum cache size reached: cache_size={models_in_ram}; unloading model {least_recently_used_key}')
                 del self.models[least_recently_used_key]
             models_in_ram = len(self.models)
         gc.collect()
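
_make_cache_room() treats self.stack as an LRU list: the least recently used key sits at index 0, so pop(0) evicts the oldest entries until the in-RAM count drops below max_models, and gc.collect() nudges Python to actually release the dropped weights. A standalone sketch of that eviction loop, using plain stand-in containers rather than the real cache:

import gc

def make_room(models: dict, stack: list, max_models: int) -> None:
    # Evict least-recently-used entries until the cache is under its limit.
    # Assumes every key in stack is also a key of models.
    while len(models) >= max_models and stack:
        least_recently_used_key = stack.pop(0)  # oldest key lives at the front
        del models[least_recently_used_key]
    gc.collect()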
@@ -291,7 +316,7 @@ class ModelCache(object):
         and option to exit if an infected file is identified.
         """
         # scan model
-        print(f" | Scanning Model: {model_name}")
+        logger.debug(f"Scanning Model: {model_name}")
         scan_result = scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             if scan_result.infected_files == 1:
@@ -299,7 +324,7 @@ class ModelCache(object):
             else:
                 raise UnscannableModelException("InvokeAI was unable to scan the legacy model you requested. Aborting")
         else:
-            print(" | Model scanned ok")
+            logger.debug("Model scanned ok")
 
     def _load_ckpt_from_storage(self,
                                 ckpt_path: Union[str,Path],
@@ -330,7 +355,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash
 
-        print(f' | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         with open(path, "rb") as f:
             while chunk := f.read(self.sha_chunksize):
                 sha.update(chunk)
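
The hash helpers read checkpoint files in sha_chunksize-byte pieces (16 MiB by default), so multi-gigabyte models are never pulled into memory at once. A minimal sketch of the same chunked-SHA256 pattern outside the class; the function name is illustrative:

import hashlib
from pathlib import Path

def chunked_sha256(path: Path, chunksize: int = 16777216) -> str:
    # Hash a large file in fixed-size chunks, mirroring the loop above.
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunksize):
            sha.update(chunk)
    return sha.hexdigest()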
@@ -350,7 +375,7 @@ class ModelCache(object):
                 hash = f.read()
             return hash
 
-        print(f' | computing hash of model {path.name}')
+        logger.debug(f'computing hash of model {path.name}')
         for file in list(path.rglob("*.ckpt")) \
             + list(path.rglob("*.safetensors")) \
             + list(path.rglob("*.pth")):