draft sd3 loading; probable VRAM leak when using quantized submodels

This commit is contained in:
Lincoln Stein 2024-06-13 00:51:00 -04:00
parent 002f8242a1
commit 03b9d17d0b
5 changed files with 32 additions and 7 deletions

View File

@ -79,6 +79,7 @@ class SubModelType(str, Enum):
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
Transformer = "transformer"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"

View File

@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
is_quantized: bool = False
loaded: bool = False
_locks: int = 0

View File

@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@ -163,8 +160,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
size = calc_model_size_by_data(model)
self.make_room(size)
is_quantized = hasattr(model, "is_quantized") and model.is_quantized
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
cache_record = CacheRecord(
key=key,
model=model,
device=self._storage_device,
is_quantized=is_quantized,
state_dict=state_dict,
size=size,
)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@ -230,19 +235,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
delete_it = False
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
if cache_entry.is_quantized:
self._delete_cache_entry(cache_entry)
delete_it = True
else:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
if delete_it:
del cache_entry
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
@ -265,6 +277,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not hasattr(cache_entry.model, "to"):
return
if cache_entry.is_quantized: # can't move quantized models around
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
@ -407,3 +422,5 @@ class ModelCache(ModelCacheBase[AnyModel]):
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
gc.collect()
TorchDevice.empty_cache()

View File

@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
# note - will be removed for load_single_file()
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusion3: "SD3", # non-functional, for completeness only
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
result = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
)
else:
raise e

View File

@ -34,6 +34,7 @@ classifiers = [
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==0.30.1",
"bitsandbytes",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",