Optimize RAM to VRAM transfer (#6312)

* avoid copying model back from cuda to cpu

* handle models that don't have state dicts

* add assertions that models need a `device()` method

* do not rely on torch.nn.Module having the device() method

* apply all patches after model is on the execution device

* fix model patching in latents too

* log patched tokenizer

* closes #6375

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
This commit is contained in:
Lincoln Stein 2024-05-24 13:06:09 -04:00 committed by GitHub
parent 7437085cac
commit 532f82cb97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 63 deletions

View File

@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput: def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer) tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder) text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras: for lora in self.clip.loras:
@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context) ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( # apply all patches while the model is on the target device
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder, text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
): ):
assert isinstance(text_encoder, CLIPTextModel) assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=patched_tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype, dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt) conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization: if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, tokenizer) log_tokenization_for_conjunction(conjunction, patched_tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@ -136,11 +134,7 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder) text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty # return zero on empty
if prompt == "" and zero_on_empty: if prompt == "" and zero_on_empty:
@ -177,20 +171,23 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context) ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with ( with (
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as ( # apply all patches while the model is on the target device
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder, text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching. tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ti_manager,
),
): ):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)) assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder) text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel( compel = Compel(
tokenizer=tokenizer, tokenizer=patched_tokenizer,
text_encoder=text_encoder, text_encoder=text_encoder,
textual_inversion_manager=ti_manager, textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype, dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@ -203,7 +200,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization: if context.config.get().log_tokenization:
# TODO: better logging for and syntax # TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, tokenizer) log_tokenization_for_conjunction(conjunction, patched_tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice # TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction) c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet, unet_info as unet,
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ModelPatcher.apply_lora_unet(unet, _lora_loader()),
): ):

View File

@ -42,10 +42,26 @@ T = TypeVar("T")
@dataclass @dataclass
class CacheRecord(Generic[T]): class CacheRecord(Generic[T]):
"""Elements of the cache.""" """
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
"""
key: str key: str
model: T model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int size: int
loaded: bool = False loaded: bool = False
_locks: int = 0 _locks: int = 0

View File

@ -20,7 +20,6 @@ context. Use like this:
import gc import gc
import math import math
import sys
import time import time
from contextlib import suppress from contextlib import suppress
from logging import Logger from logging import Logger
@ -162,7 +161,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
if key in self._cached_models: if key in self._cached_models:
return return
self.make_room(size) self.make_room(size)
cache_record = CacheRecord(key, model, size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record self._cached_models[key] = cache_record
self._cache_stack.append(key) self._cache_stack.append(key)
@ -257,17 +258,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")): if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return return
source_device = cache_entry.model.device source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. # Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU. # This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type: if torch.device(source_device).type == torch.device(target_device).type:
return return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time() start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot() snapshot_before = self._capture_memory_snapshot()
try: try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device) cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry) self._delete_cache_entry(cache_entry)
raise e raise e
@ -347,43 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos] model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key] cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manually clear local variable references of just-finished function calls
# for some reason Python doesn't collect them immediately, even after gc.collect()
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug( self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}," f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
f" refs: {refs}"
) )
# Expected refs: if not cache_entry.locked:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug( self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)" f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
) )