mirror of https://github.com/invoke-ai/InvokeAI
draft sd3 loading; probable VRAM leak when using quantized submodels
parent 002f8242a1
commit 03b9d17d0b
@@ -79,6 +79,7 @@ class SubModelType(str, Enum):
     Tokenizer = "tokenizer"
     Tokenizer2 = "tokenizer_2"
+    Tokenizer3 = "tokenizer_3"
     Transformer = "transformer"
     VAE = "vae"
     VAEDecoder = "vae_decoder"
     VAEEncoder = "vae_encoder"
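For illustration (not part of this commit): the str-valued enum above lets each member double as the diffusers subfolder name on disk, e.g. "tokenizer_3" for the third (T5) tokenizer that SD3 introduces. A minimal sketch of the pattern, trimmed to the members visible in this hunk:

    from enum import Enum

    class SubModelType(str, Enum):
        """Trimmed sketch; values are the diffusers subfolder names."""

        Tokenizer = "tokenizer"
        Tokenizer2 = "tokenizer_2"
        Tokenizer3 = "tokenizer_3"  # new for SD3: the T5 tokenizer lives in tokenizer_3/
        Transformer = "transformer"
        VAE = "vae"

    # Because the enum mixes in str, members compare equal to the raw subfolder string.
    assert SubModelType.Tokenizer3 == "tokenizer_3"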
@@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
     device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
+    is_quantized: bool = False
     loaded: bool = False
     _locks: int = 0

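For illustration (not part of this commit): the new is_quantized flag defaults to False, so existing CacheRecord construction sites keep working unchanged. A hedged sketch of the record as it appears in this hunk and in the constructor call further down; the ordering of the fields the diff does not show (key, model) and any methods are assumptions:

    from dataclasses import dataclass
    from typing import Dict, Generic, Optional, TypeVar

    import torch

    T = TypeVar("T")

    @dataclass
    class CacheRecord(Generic[T]):
        key: str
        model: T
        device: torch.device
        state_dict: Optional[Dict[str, torch.Tensor]]
        size: int
        is_quantized: bool = False  # new: marks entries that cannot be offloaded to CPU
        loaded: bool = False
        _locks: int = 0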
@@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
         precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
         lazy_offloading: bool = True,
-        sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):
@@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
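For illustration (not part of this commit): a hedged usage sketch assembled only from the constructor parameters and docstring visible in the two hunks above. The cache accepts further arguments (for example, size limits) that this diff does not show, and the import path is assumed rather than taken from the diff.

    import logging

    import torch

    # from invokeai.backend.model_manager.load.model_cache import ModelCache  # import path assumed

    cache = ModelCache(
        execution_device=torch.device("cuda"),
        storage_device=torch.device("cpu"),
        precision=torch.float16,
        lazy_offloading=True,    # keep models in VRAM until another model needs the space
        log_memory_usage=False,  # memory snapshots are slow; enable only while debugging the cache
        logger=logging.getLogger("ModelCache"),
    )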
@@ -163,8 +160,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(model)
         self.make_room(size)

+        is_quantized = hasattr(model, "is_quantized") and model.is_quantized
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        cache_record = CacheRecord(
+            key=key,
+            model=model,
+            device=self._storage_device,
+            is_quantized=is_quantized,
+            state_dict=state_dict,
+            size=size,
+        )
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)

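For illustration (not part of this commit): the detection above is duck-typed. Models quantized through bitsandbytes (for example via transformers) typically expose an is_quantized attribute, while plain torch modules do not, so the hasattr check degrades gracefully. A runnable sketch of the same pattern with a hypothetical stand-in class:

    import torch

    class FakeQuantizedModel(torch.nn.Module):
        """Hypothetical stand-in for a bitsandbytes-quantized submodel."""

        is_quantized = True

    for model in (torch.nn.Linear(4, 4), FakeQuantizedModel()):
        is_quantized = hasattr(model, "is_quantized") and model.is_quantized
        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
        print(type(model).__name__, is_quantized, state_dict is not None)
        # Linear             False True
        # FakeQuantizedModel True  True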
@@ -230,19 +235,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
         reserved = self._max_vram_cache_size * GIG
         vram_in_use = torch.cuda.memory_allocated() + size_required
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+        delete_it = False
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.loaded:
                 continue
             if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
+                if cache_entry.is_quantized:
+                    self._delete_cache_entry(cache_entry)
+                    delete_it = True
+                else:
+                    self.move_model_to_device(cache_entry, self.storage_device)
+                    cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
                 self.logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
                 )

+        if delete_it:
+            del cache_entry
         TorchDevice.empty_cache()

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
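For illustration (not part of this commit): a simplified, self-contained sketch of the eviction policy above. Entries are visited smallest-first and freed until estimated VRAM use fits the budget; quantized entries cannot be offloaded to the CPU, so they are dropped outright, which is also where the "probable VRAM leak" in the commit title can arise if some other reference still keeps the quantized model alive. Names and the budget bookkeeping are illustrative, not InvokeAI's real implementation.

    from dataclasses import dataclass

    @dataclass
    class Entry:
        key: str
        size: int
        loaded: bool = True
        locked: bool = False
        is_quantized: bool = False

    def make_room(entries: dict, in_use: int, budget: int) -> int:
        """Free VRAM until in_use fits under budget; return the new usage estimate."""
        for key, entry in sorted(entries.items(), key=lambda kv: kv[1].size):
            if in_use <= budget:
                break
            if not entry.loaded or entry.locked:
                continue
            if entry.is_quantized:
                entries.pop(key)      # cannot move it off the GPU: drop the cache entry
            else:
                entry.loaded = False  # ordinary model: offload to CPU (simulated)
            in_use -= entry.size
        return in_use

    cache = {
        "vae": Entry("vae", size=1),
        "t5_8bit": Entry("t5_8bit", size=5, is_quantized=True),
        "transformer": Entry("transformer", size=8, locked=True),
    }
    print(make_room(cache, in_use=14, budget=9))  # -> 8: "vae" offloaded, "t5_8bit" deleted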
@@ -265,6 +277,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if not hasattr(cache_entry.model, "to"):
             return

+        if cache_entry.is_quantized:  # can't move quantized models around
+            return
+
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
         # When moving to VRAM, we copy (not move) each element of the state dict from
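For illustration (not part of this commit): a hedged sketch of the "copy, don't move" idea the truncated comment describes, assuming the cache keeps a CPU-resident state dict for each entry, copies its tensors onto the GPU when the model is loaded, and simply re-assigns the CPU copies when it is unloaded so nothing needs to be copied back out of VRAM. The details of InvokeAI's actual implementation are not shown in this diff.

    import torch

    model = torch.nn.Linear(16, 16)
    cpu_state_dict = model.state_dict()  # kept for the lifetime of the cache entry

    if torch.cuda.is_available():
        # To VRAM: copy each tensor; the CPU originals stay valid.
        gpu_state_dict = {k: v.to("cuda", copy=True) for k, v in cpu_state_dict.items()}
        model.load_state_dict(gpu_state_dict, assign=True)

        # Back to CPU: no device-to-host transfer, just re-assign the CPU tensors.
        model.load_state_dict(cpu_state_dict, assign=True)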
@@ -407,3 +422,5 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
+        gc.collect()
+        TorchDevice.empty_cache()
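For illustration (not part of this commit): dropping the Python reference alone is often not enough to return VRAM promptly. Reference cycles keep objects alive until the garbage collector runs, and the CUDA caching allocator keeps freed blocks reserved until empty_cache() is called; TorchDevice.empty_cache() is presumably a device-agnostic wrapper around the same idea. A minimal runnable sketch:

    import gc

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(1024, 1024).to(device)

    del model                     # drop the last reference we hold
    gc.collect()                  # collect any reference cycles promptly
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached, freed blocks back to the driver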
@@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     """Class to load main models."""

+    # note - will be removed for load_single_file()
     model_base_to_model_type = {
         BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
         BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
+        BaseModelType.StableDiffusion3: "SD3",  # non-functional, for completeness only
         BaseModelType.StableDiffusionXL: "SDXL",
         BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
     }
@@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
+                result = load_class.from_pretrained(
+                    model_path,
+                    torch_dtype=self._torch_dtype,
+                )
             else:
                 raise e

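For illustration (not part of this commit): the reformatted call above is the fallback path for when variant-specific weights (for example *.fp16.safetensors) are missing from the model directory. A hedged sketch of the same pattern against the public diffusers API; AutoencoderKL, the path, and the "fp16" variant are illustrative choices, not what the loader necessarily uses.

    import torch
    from diffusers import AutoencoderKL

    model_path = "/path/to/model/vae"  # placeholder

    try:
        vae = AutoencoderKL.from_pretrained(model_path, torch_dtype=torch.float16, variant="fp16")
    except OSError as e:
        if "no file named" in str(e):  # variant weights absent: retry without the variant
            vae = AutoencoderKL.from_pretrained(model_path, torch_dtype=torch.float16)
        else:
            raise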
@@ -34,6 +34,7 @@ classifiers = [
 dependencies = [
     # Core generation dependencies, pinned for reproducible builds.
     "accelerate==0.30.1",
+    "bitsandbytes",
     "clip_anytorch==2.6.0",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel==2.0.2",
     "controlnet-aux==0.0.7",
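For illustration (not part of this commit): why bitsandbytes joins the dependencies. A hedged sketch of loading a large SD3 submodel (the T5 text encoder) in 8-bit through the transformers quantization API; the repo id is illustrative and this is not necessarily the loading path InvokeAI uses. A model loaded this way reports is_quantized as True and cannot be moved between devices with .to(), which is exactly the case the cache changes above special-case.

    import torch
    from transformers import BitsAndBytesConfig, T5EncoderModel

    text_encoder_3 = T5EncoderModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative repo id
        subfolder="text_encoder_3",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,
    )
    print(text_encoder_3.is_quantized)  # expected: True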