Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Optimize RAM to VRAM transfer (#6312)
* avoid copying model back from cuda to cpu
* handle models that don't have state dicts
* add assertions that models need a `device()` method
* do not rely on torch.nn.Module having the device() method
* apply all patches after model is on the execution device
* fix model patching in latents too
* log patched tokenizer
* closes #6375

Co-authored-by: Lincoln Stein <lstein@gmail.com>
commit 532f82cb97
parent 7437085cac
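Before the diff itself, here is a minimal sketch of the optimization the commit message describes: the CPU copy of the weights is kept as a read-only state-dict template, a per-tensor copy is pushed to VRAM for execution, and unloading simply reinjects the cached CPU tensors instead of copying anything back from the GPU. The helper names (to_vram, to_ram) are illustrative, not InvokeAI's API, and load_state_dict(assign=True) assumes PyTorch >= 2.1.

from typing import Dict

import torch


def to_vram(model: torch.nn.Module, cpu_template: Dict[str, torch.Tensor], device: torch.device) -> None:
    # Copy (not move) every tensor of the CPU template to the target device,
    # then assign the new tensors to the module's parameters and buffers.
    gpu_dict = {k: v.to(device, copy=True) for k, v in cpu_template.items()}
    model.load_state_dict(gpu_dict, assign=True)


def to_ram(model: torch.nn.Module, cpu_template: Dict[str, torch.Tensor]) -> None:
    # Unloading is just a reassignment of the cached CPU tensors; nothing is
    # copied back from VRAM, and the VRAM copies become garbage.
    model.load_state_dict(cpu_template, assign=True)


if __name__ == "__main__":
    m = torch.nn.Linear(8, 8)
    template = m.state_dict()  # shares storage with the module's CPU weights
    if torch.cuda.is_available():
        to_vram(m, template, torch.device("cuda"))
        # ... run inference on the GPU ...
        to_ram(m, template)

The diff follows.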
@@ -65,11 +65,7 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(self.clip.tokenizer)
-        tokenizer_model = tokenizer_info.model
-        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)
-        text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, CLIPTextModel)
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
@@ -84,19 +80,21 @@ class CompelInvocation(BaseInvocation):
         ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
 
         with (
-            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
-                tokenizer,
-                ti_manager,
-            ),
+            # apply all patches while the model is on the target device
             text_encoder_info as text_encoder,
-            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
+            ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
+                patched_tokenizer,
+                ti_manager,
+            ),
         ):
             assert isinstance(text_encoder, CLIPTextModel)
+            assert isinstance(tokenizer, CLIPTokenizer)
             compel = Compel(
-                tokenizer=tokenizer,
+                tokenizer=patched_tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=ti_manager,
                 dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -106,7 +104,7 @@ class CompelInvocation(BaseInvocation):
             conjunction = Compel.parse_prompt_string(self.prompt)
 
             if context.config.get().log_tokenization:
-                log_tokenization_for_conjunction(conjunction, tokenizer)
+                log_tokenization_for_conjunction(conjunction, patched_tokenizer)
 
             c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
 
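A note on the reordering in the two hunks above (and the matching SDXL and UNet hunks below): a parenthesized with statement enters its context managers top to bottom, so moving the ModelPatcher.apply_lora_*, apply_clip_skip and apply_ti managers below text_encoder_info/tokenizer_info guarantees the patches are applied to a model that is already on the execution device — the "apply all patches after model is on the execution device" item from the commit message. A generic illustration of the entry order (Python 3.10+ syntax, not InvokeAI code):

from contextlib import contextmanager


@contextmanager
def step(name: str):
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


with (
    step("load model onto the execution device") as model,
    step("apply LoRA to the on-device model"),
    step("apply textual inversions"),
):
    print("run inference with", model)
# Prints the enter lines top to bottom, then the body, then the exits in reverse order.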
@@ -136,11 +134,7 @@ class SDXLPromptInvocationBase:
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         tokenizer_info = context.models.load(clip_field.tokenizer)
-        tokenizer_model = tokenizer_info.model
-        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(clip_field.text_encoder)
-        text_encoder_model = text_encoder_info.model
-        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -177,20 +171,23 @@ class SDXLPromptInvocationBase:
         ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
 
         with (
-            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
-                tokenizer,
-                ti_manager,
-            ),
+            # apply all patches while the model is on the target device
             text_encoder_info as text_encoder,
-            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
+            ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
+                patched_tokenizer,
+                ti_manager,
+            ),
         ):
             assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
+            assert isinstance(tokenizer, CLIPTokenizer)
+
             text_encoder = cast(CLIPTextModel, text_encoder)
             compel = Compel(
-                tokenizer=tokenizer,
+                tokenizer=patched_tokenizer,
                 text_encoder=text_encoder,
                 textual_inversion_manager=ti_manager,
                 dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -203,7 +200,7 @@ class SDXLPromptInvocationBase:
 
             if context.config.get().log_tokenization:
                 # TODO: better logging for and syntax
-                log_tokenization_for_conjunction(conjunction, tokenizer)
+                log_tokenization_for_conjunction(conjunction, patched_tokenizer)
 
             # TODO: ask for optimizations? to not run text_encoder twice
             c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
-            set_seamless(unet_info.model, self.unet.seamless_axes),  # FIXME
             unet_info as unet,
+            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
+            set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
@@ -42,10 +42,26 @@ T = TypeVar("T")
 
 @dataclass
 class CacheRecord(Generic[T]):
-    """Elements of the cache."""
+    """
+    Elements of the cache:
+
+    key: Unique key for each model, same as used in the models database.
+    model: Model in memory.
+    state_dict: A read-only copy of the model's state dict in RAM. It will be
+                used as a template for creating a copy in the VRAM.
+    size: Size of the model
+    loaded: True if the model's state dict is currently in VRAM
+
+    Before a model is executed, the state_dict template is copied into VRAM,
+    and then injected into the model. When the model is finished, the VRAM
+    copy of the state dict is deleted, and the RAM version is reinjected
+    into the model.
+    """
 
     key: str
     model: T
+    device: torch.device
+    state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
     loaded: bool = False
     _locks: int = 0
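The new state_dict field added to CacheRecord above is cheap to keep around because Module.state_dict() returns detached tensors that share storage with the module's parameters, so the RAM "template" does not duplicate the weights. A small check of that property (with the default keep_vars=False):

import torch

m = torch.nn.Linear(4, 4)
sd = m.state_dict()

# Same underlying storage as the live parameter -> no extra RAM for the template.
print(sd["weight"].data_ptr() == m.weight.data_ptr())  # True
# Detached view, so it is safe to treat the template as read-only.
print(sd["weight"].requires_grad)  # False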
@@ -20,7 +20,6 @@ context. Use like this:
 
 import gc
 import math
-import sys
 import time
 from contextlib import suppress
 from logging import Logger
@@ -162,7 +161,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if key in self._cached_models:
             return
         self.make_room(size)
-        cache_record = CacheRecord(key, model, size)
+
+        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
+        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
         self._cached_models[key] = cache_record
         self._cache_stack.append(key)
 
@@ -257,17 +258,37 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
             return
 
-        source_device = cache_entry.model.device
+        source_device = cache_entry.device
 
         # Note: We compare device types only so that 'cuda' == 'cuda:0'.
        # This would need to be revised to support multi-GPU.
         if torch.device(source_device).type == torch.device(target_device).type:
             return
 
+        # This roundabout method for moving the model around is done to avoid
+        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
+        # When moving to VRAM, we copy (not move) each element of the state dict from
+        # RAM to a new state dict in VRAM, and then inject it into the model.
+        # This operation is slightly faster than running `to()` on the whole model.
+        #
+        # When the model needs to be removed from VRAM we simply delete the copy
+        # of the state dict in VRAM, and reinject the state dict that is cached
+        # in RAM into the model. So this operation is very fast.
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
+
         try:
+            if cache_entry.state_dict is not None:
+                assert hasattr(cache_entry.model, "load_state_dict")
+                if target_device == self.storage_device:
+                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
+                else:
+                    new_dict: Dict[str, torch.Tensor] = {}
+                    for k, v in cache_entry.state_dict.items():
+                        new_dict[k] = v.to(torch.device(target_device), copy=True)
+                    cache_entry.model.load_state_dict(new_dict, assign=True)
             cache_entry.model.to(target_device)
+            cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
             raise e
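The fast path added above hinges on load_state_dict(..., assign=True), available since PyTorch 2.1: instead of copying values element-wise into the existing parameters, the module's parameters are replaced by the tensors in the given state dict, which is what lets the cache swap whole RAM/VRAM tensor sets in and out cheaply. An illustrative contrast of the two modes (not InvokeAI code):

import torch

m = torch.nn.Linear(4, 4)
sd = {k: v.clone() for k, v in m.state_dict().items()}

# Default mode: values are copied in place into the existing parameter tensors.
m.load_state_dict(sd)
print(m.weight.data_ptr() == sd["weight"].data_ptr())  # False

# assign=True: the parameters now *are* the state-dict tensors.
m.load_state_dict(sd, assign=True)
print(m.weight.data_ptr() == sd["weight"].data_ptr())  # True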
@@ -347,43 +368,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]
-
-            refs = sys.getrefcount(cache_entry.model)
-
-            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
-            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
-            # https://docs.python.org/3/library/gc.html#gc.get_referrers
-
-            # manualy clear local variable references of just finished function calls
-            # for some reason python don't want to collect it even by gc.collect() immidiately
-            if refs > 2:
-                while True:
-                    cleared = False
-                    for referrer in gc.get_referrers(cache_entry.model):
-                        if type(referrer).__name__ == "frame":
-                            # RuntimeError: cannot clear an executing frame
-                            with suppress(RuntimeError):
-                                referrer.clear()
-                                cleared = True
-                            # break
-
-                    # repeat if referrers changes(due to frame clear), else exit loop
-                    if cleared:
-                        gc.collect()
-                    else:
-                        break
-
             device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
             self.logger.debug(
-                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
-                f" refs: {refs}"
+                f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
             )
 
-            # Expected refs:
-            # 1 from cache_entry
-            # 1 from getrefcount function
-            # 1 from onnx runtime object
-            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
+            if not cache_entry.locked:
                 self.logger.debug(
                     f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
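The hunk above drops the refcount and gc.get_referrers() workaround; eviction now only checks cache_entry.locked. For context, the removed heuristic was fragile because sys.getrefcount() counts whatever incidental references happen to exist at call time:

import sys

model = object()
print(sys.getrefcount(model))  # typically 2: the local name plus getrefcount's own argument

alias = model                  # any extra reference shifts the count,
print(sys.getrefcount(model))  # so a fixed threshold such as "refs <= 2" is easy to break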