draft sd3 loading; probable VRAM leak when using quantized submodels

This commit is contained in:
Lincoln Stein 2024-06-13 00:51:00 -04:00
parent 002f8242a1
commit 03b9d17d0b
5 changed files with 32 additions and 7 deletions

View File

@ -79,6 +79,7 @@ class SubModelType(str, Enum):
Tokenizer = "tokenizer" Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2" Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3" Tokenizer3 = "tokenizer_3"
Transformer = "transformer"
VAE = "vae" VAE = "vae"
VAEDecoder = "vae_decoder" VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder" VAEEncoder = "vae_encoder"

View File

@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
device: torch.device device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]] state_dict: Optional[Dict[str, torch.Tensor]]
size: int size: int
is_quantized: bool = False
loaded: bool = False loaded: bool = False
_locks: int = 0 _locks: int = 0

View File

@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
execution_device: torch.device = torch.device("cuda"), execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"), storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16, precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True, lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False, log_memory_usage: bool = False,
logger: Optional[Logger] = None, logger: Optional[Logger] = None,
): ):
@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
:param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16] :param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@ -163,8 +160,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
size = calc_model_size_by_data(model) size = calc_model_size_by_data(model)
self.make_room(size) self.make_room(size)
is_quantized = hasattr(model, "is_quantized") and model.is_quantized
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) cache_record = CacheRecord(
key=key,
model=model,
device=self._storage_device,
is_quantized=is_quantized,
state_dict=state_dict,
size=size,
)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)
@ -230,12 +235,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
reserved = self._max_vram_cache_size * GIG reserved = self._max_vram_cache_size * GIG
vram_in_use = torch.cuda.memory_allocated() + size_required vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
delete_it = False
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved: if vram_in_use <= reserved:
break break
if not cache_entry.loaded: if not cache_entry.loaded:
continue continue
if not cache_entry.locked: if not cache_entry.locked:
if cache_entry.is_quantized:
self._delete_cache_entry(cache_entry)
delete_it = True
else:
self.move_model_to_device(cache_entry, self.storage_device) self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required vram_in_use = torch.cuda.memory_allocated() + size_required
@ -243,6 +253,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
) )
if delete_it:
del cache_entry
TorchDevice.empty_cache() TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
@ -265,6 +277,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not hasattr(cache_entry.model, "to"): if not hasattr(cache_entry.model, "to"):
return return
if cache_entry.is_quantized: # can't move quantized models around
return
# This roundabout method for moving the model around is done to avoid # This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from # When moving to VRAM, we copy (not move) each element of the state dict from
@ -407,3 +422,5 @@ class ModelCache(ModelCacheBase[AnyModel]):
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None: def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
self._cache_stack.remove(cache_entry.key) self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key] del self._cached_models[cache_entry.key]
gc.collect()
TorchDevice.empty_cache()

View File

@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader): class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models.""" """Class to load main models."""
# note - will be removed for load_single_file()
model_base_to_model_type = { model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusion3: "SD3", # non-functional, for completeness only
BaseModelType.StableDiffusionXL: "SDXL", BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
} }
@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
if variant and "no file named" in str( if variant and "no file named" in str(
e e
): # try without the variant, just in case user's preferences changed ): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype) result = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
)
else: else:
raise e raise e

View File

@ -34,6 +34,7 @@ classifiers = [
dependencies = [ dependencies = [
# Core generation dependencies, pinned for reproducible builds. # Core generation dependencies, pinned for reproducible builds.
"accelerate==0.30.1", "accelerate==0.30.1",
"bitsandbytes",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2", "compel==2.0.2",
"controlnet-aux==0.0.7", "controlnet-aux==0.0.7",