Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
commit 03b9d17d0b (parent 002f8242a1)

    draft sd3 loading; probable VRAM leak when using quantized submodels
@@ -79,6 +79,7 @@ class SubModelType(str, Enum):
     Tokenizer = "tokenizer"
     Tokenizer2 = "tokenizer_2"
     Tokenizer3 = "tokenizer_3"
+    Transformer = "transformer"
     VAE = "vae"
     VAEDecoder = "vae_decoder"
     VAEEncoder = "vae_encoder"
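The new Transformer key names the MMDiT denoiser stored in the transformer/ subfolder of a diffusers-format SD3 checkpoint. As a rough illustration of what that submodel is, outside InvokeAI's own loader machinery, here is a minimal sketch; it assumes diffusers >= 0.29, and the checkpoint path is hypothetical:

# Minimal sketch: load the SD3 MMDiT denoiser that SubModelType.Transformer refers to.
# Assumes diffusers >= 0.29; the model directory path below is hypothetical.
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "/models/stable-diffusion-3-medium-diffusers",  # hypothetical local path
    subfolder="transformer",
    torch_dtype=torch.float16,
)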
@@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
     device: torch.device
     state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
+    is_quantized: bool = False
     loaded: bool = False
     _locks: int = 0
 
@@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
         precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
         lazy_offloading: bool = True,
-        sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):
@@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@@ -163,8 +160,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
         size = calc_model_size_by_data(model)
         self.make_room(size)
 
+        is_quantized = hasattr(model, "is_quantized") and model.is_quantized
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
+        cache_record = CacheRecord(
+            key=key,
+            model=model,
+            device=self._storage_device,
+            is_quantized=is_quantized,
+            state_dict=state_dict,
+            size=size,
+        )
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
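The hasattr() guard matters because only some frameworks set this flag. For example, transformers marks models loaded through a bitsandbytes quantization config with is_quantized = True. A hedged sketch of how a submodel can arrive at the cache already quantized; it assumes transformers and bitsandbytes are installed, and the model path is hypothetical:

# Sketch: one way a submodel ends up with is_quantized == True before it reaches the cache.
# Assumes transformers + bitsandbytes; the checkpoint path is hypothetical.
from transformers import BitsAndBytesConfig, T5EncoderModel

text_encoder_3 = T5EncoderModel.from_pretrained(
    "/models/stable-diffusion-3-medium-diffusers",  # hypothetical local path
    subfolder="text_encoder_3",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
# Mirrors the check in the diff: False for plain models, True for bnb-quantized ones.
print(hasattr(text_encoder_3, "is_quantized") and text_encoder_3.is_quantized)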
@@ -230,19 +235,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
         reserved = self._max_vram_cache_size * GIG
         vram_in_use = torch.cuda.memory_allocated() + size_required
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
+        delete_it = False
         for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.loaded:
                 continue
             if not cache_entry.locked:
-                self.move_model_to_device(cache_entry, self.storage_device)
-                cache_entry.loaded = False
+                if cache_entry.is_quantized:
+                    self._delete_cache_entry(cache_entry)
+                    delete_it = True
+                else:
+                    self.move_model_to_device(cache_entry, self.storage_device)
+                    cache_entry.loaded = False
                 vram_in_use = torch.cuda.memory_allocated() + size_required
                 self.logger.debug(
                     f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
                 )
 
+        if delete_it:
+            del cache_entry
         TorchDevice.empty_cache()
 
     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
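The commit message flags a probable VRAM leak here: del cache_entry only removes the loop variable's reference to the record, and CUDA memory is returned only once every live Python reference to the quantized model's tensors is gone. A minimal sketch of the release pattern this code relies on, which only works under the assumption that no other references to the model survive:

# Sketch: releasing a GPU-resident model requires dropping *all* Python references
# before garbage collection and cache clearing have any effect.
import gc

import torch


def release_cuda_model(model: torch.nn.Module) -> None:
    """Best-effort release; only frees VRAM if no other references to `model` remain."""
    del model                  # drop this function's reference to the model
    gc.collect()               # collect any reference cycles still pointing at the weights
    torch.cuda.empty_cache()   # hand freed blocks back from PyTorch's caching allocator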
@@ -265,6 +277,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if not hasattr(cache_entry.model, "to"):
             return
 
+        if cache_entry.is_quantized:  # can't move quantized models around
+            return
+
         # This roundabout method for moving the model around is done to avoid
         # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
         # When moving to VRAM, we copy (not move) each element of the state dict from
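The early return reflects the fact that bitsandbytes-quantized modules are tied to the device they were quantized on; transformers, for instance, refuses .to() calls on 8-bit models. A hedged sketch of the behaviour the guard sidesteps, under the same transformers + bitsandbytes assumptions as the earlier sketch and with a hypothetical path:

# Sketch: device moves on a bnb-quantized model are rejected, which is why the cache
# skips the state-dict shuffle for quantized entries.
from transformers import BitsAndBytesConfig, T5EncoderModel

model = T5EncoderModel.from_pretrained(
    "/models/text_encoder_3",  # hypothetical local path
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
try:
    model.to("cpu")
except ValueError as err:
    print(err)  # transformers blocks .to() device moves for 8-bit quantized models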
@@ -407,3 +422,5 @@ class ModelCache(ModelCacheBase[AnyModel]):
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
         self._cache_stack.remove(cache_entry.key)
         del self._cached_models[cache_entry.key]
+        gc.collect()
+        TorchDevice.empty_cache()
@@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
 class StableDiffusionDiffusersModel(GenericDiffusersLoader):
     """Class to load main models."""
 
+    # note - will be removed for load_single_file()
     model_base_to_model_type = {
         BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
         BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
+        BaseModelType.StableDiffusion3: "SD3",  # non-functional, for completeness only
         BaseModelType.StableDiffusionXL: "SDXL",
         BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
     }
@@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
             if variant and "no file named" in str(
                 e
             ):  # try without the variant, just in case user's preferences changed
-                result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
+                result = load_class.from_pretrained(
+                    model_path,
+                    torch_dtype=self._torch_dtype,
+                )
             else:
                 raise e
 
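The reformatted call is the fallback path of a variant-aware load: if the requested weight variant (e.g. fp16) is not present on disk, the loader retries without it. A standalone sketch of the same pattern, assuming diffusers; AutoencoderKL and the fp16 variant are illustrative choices, not something the diff prescribes:

# Sketch of the variant-fallback pattern used by the loader above.
import torch
from diffusers import AutoencoderKL


def load_with_variant_fallback(model_path: str, variant: str = "fp16") -> AutoencoderKL:
    try:
        return AutoencoderKL.from_pretrained(model_path, torch_dtype=torch.float16, variant=variant)
    except OSError as e:
        if variant and "no file named" in str(e):
            # The variant's weight files are missing; retry with the default weights.
            return AutoencoderKL.from_pretrained(model_path, torch_dtype=torch.float16)
        raise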
@@ -34,6 +34,7 @@ classifiers = [
 dependencies = [
     # Core generation dependencies, pinned for reproducible builds.
     "accelerate==0.30.1",
+    "bitsandbytes",
     "clip_anytorch==2.6.0",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel==2.0.2",
     "controlnet-aux==0.0.7",