fix compel conditioning object caching issue by applying deepcopy() before moving to VRAM

Lincoln Stein 2024-07-18 14:53:03 -04:00
parent 5d6a77d336
commit 02957be333
9 changed files with 77 additions and 125 deletions

View File

@@ -1,6 +1,7 @@
from typing import Iterator, List, Optional, Tuple, Union, cast
import torch
import threading
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -139,6 +140,7 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tid = threading.current_thread().ident
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
@@ -205,6 +207,7 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)
conjunction = Compel.parse_prompt_string(prompt)
@@ -315,7 +318,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput(

View File

@@ -1,5 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect
import threading
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -192,10 +194,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = []
tid = threading.current_thread().ident
for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name)
cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
@@ -317,6 +319,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
tid = threading.current_thread().ident
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
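
The `copy.deepcopy()` call added above is the heart of this commit: `BasicConditioningInfo.to()` reassigns its tensor attributes in place, so without a private copy the first thread to call it would drag the single cached conditioning object onto its own GPU, and every other thread would then read tensors on the wrong device. A minimal sketch of the hazard, using a hypothetical `CondInfo` stand-in rather than the real conditioning classes:

    import copy
    import torch

    class CondInfo:
        """Stand-in for a cached conditioning object whose .to() mutates attributes."""

        def __init__(self, embeds: torch.Tensor) -> None:
            self.embeds = embeds

        def to(self, device: torch.device) -> "CondInfo":
            self.embeds = self.embeds.to(device)  # reassignment mutates the shared object
            return self

    cache = {"prompt-1": CondInfo(torch.zeros(2, 77, 768))}  # one entry shared by all threads

    # Without a copy, one thread's relocation would leak into every later lookup:
    #   cache["prompt-1"].to(torch.device("cuda:1"))

    # With deepcopy, each thread moves a private clone and the cached original is untouched:
    private = copy.deepcopy(cache["prompt-1"]).to(torch.device("cpu"))
    assert cache["prompt-1"].embeds.device.type == "cpu"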

View File

@@ -10,6 +10,7 @@ import torch
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker
@@ -46,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
def load(self, name: str) -> T:
file_path = self._get_path(name)
try:
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
return torch.load(file_path, map_location=TorchDevice.choose_torch_device()) # pyright: ignore [reportUnknownMemberType]
except FileNotFoundError as e:
raise ObjectNotFoundError(name) from e
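
For context on the `map_location` argument introduced above: `torch.load` applies it to every tensor as the file is deserialized, so an object serialized by one worker can be opened directly on whatever device the loading thread has reserved. A small self-contained illustration (the path and payload are made up):

    import os
    import tempfile

    import torch

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example.pt")
        torch.save({"embeds": torch.randn(2, 77, 768)}, path)

        # map_location relocates each tensor during deserialization, so a checkpoint
        # written from a CUDA context can be opened on a CPU-only worker and vice versa.
        obj = torch.load(path, map_location=torch.device("cpu"))
        assert obj["embeds"].device.type == "cpu"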

View File

@@ -53,25 +53,12 @@ class CacheRecord(Generic[T]):
key: Unique key for each model, same as used in the models database.
model: Read-only copy of the model *without weights* residing in the "meta device"
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
key: str
size: int
model: T
state_dict: Optional[Dict[str, torch.Tensor]]
@dataclass

View File

@@ -159,7 +159,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
device = free_device[0]
# we are outside the lock region now
self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
self.logger.info(f"{current_thread} Reserved torch device {device}")
# Tell TorchDevice to use this object to get the torch device.
TorchDevice.set_model_cache(self)
@@ -167,7 +167,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
yield device
finally:
with self._device_lock:
self.logger.info(f"Released torch device {device}")
self.logger.info(f"{current_thread} Released torch device {device}")
self._execution_devices[device] = 0
self._free_execution_device.release()
torch.cuda.empty_cache()
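
The two log lines changed above are emitted when a worker thread checks an execution device in and out. The surrounding mechanism, only partly visible in this hunk, pairs a semaphore (to block when every device is busy) with a lock-guarded ownership table. A rough sketch of that pattern, with illustrative names rather than the cache's real attributes:

    import threading
    from contextlib import contextmanager

    import torch

    class DevicePool:
        """Illustrative per-thread execution-device reservation."""

        def __init__(self, devices: list[torch.device]) -> None:
            self._owners = {d: 0 for d in devices}  # 0 = free, otherwise the holder's thread id
            self._lock = threading.Lock()
            self._free = threading.BoundedSemaphore(len(devices))

        @contextmanager
        def reserve(self):
            tid = threading.current_thread().ident or 0
            self._free.acquire()  # block until some device is free
            with self._lock:
                device = next(d for d, owner in self._owners.items() if owner == 0)
                self._owners[device] = tid
            try:
                yield device
            finally:
                with self._lock:
                    self._owners[device] = 0
                self._free.release()

    # Usage: with DevicePool([torch.device("cpu")]).reserve() as device: ...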
@@ -215,18 +215,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
with self._ram_lock:
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
self.make_room(size)
if isinstance(model, torch.nn.Module):
state_dict = model.state_dict() # keep a master copy of the state dict
model = model.to(device="meta") # and keep a template in the meta device
else:
state_dict = None
cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
tid = threading.current_thread().ident
cache_record = CacheRecord(key=key, model=model, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -296,11 +293,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
May raise a torch.cuda.OutOfMemoryError
"""
self.logger.info(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
with self._ram_lock:
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
# Some models don't have a state dictionary, in which case the
# stored model will still reside in CPU
if cache_entry.state_dict is None:
if hasattr(cache_entry.model, "to"):
model_in_gpu = copy.deepcopy(cache_entry.model)
assert hasattr(model_in_gpu, "to")
@@ -309,65 +306,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
return cache_entry.model # what happens in CPU stays in CPU
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
assert isinstance(cache_entry.model, torch.nn.Module)
template = cache_entry.model
cls = template.__class__
with skip_torch_weight_init():
if isinstance(cls, ConfigMixin) or hasattr(cls, "from_config"):
working_model = template.__class__.from_config(template.config) # diffusers style
else:
working_model = template.__class__(config=template.config) # transformers style (sigh)
working_model.to(device=target_device, dtype=self._precision)
working_model.load_state_dict(cache_entry.state_dict)
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.info(
f"Moved model '{cache_entry.key}' to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
return working_model
def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@@ -445,8 +383,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
raise torch.cuda.OutOfMemoryError
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
try:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
except ValueError:
pass
@staticmethod
def _device_name(device: torch.device) -> str:

View File

@@ -31,10 +31,6 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
try:
@@ -56,3 +52,9 @@
def unlock(self) -> None:
"""Call upon exit from context."""
self._cache.print_cuda_stats()
# This is no longer in use in MGPU.
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return None

View File

@@ -2,7 +2,7 @@
"""These classes implement model patching with LoRAs and Textual Inversions."""
from __future__ import annotations
import threading
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -34,6 +34,9 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
_thread_lock = threading.Lock()
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@@ -106,7 +109,10 @@ class ModelPatcher:
"""
original_weights = {}
try:
with torch.no_grad():
with (
torch.no_grad(),
cls._thread_lock
):
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
@@ -156,9 +162,6 @@ class ModelPatcher:
yield # wait for context manager exit
finally:
# LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
# Therefore it should not be necessary to copy the original model weights back.
# This needs to be fixed before resurrecting the VRAM cache.
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
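
The new class-level `_thread_lock` serializes the in-place weight edits so that two threads cannot patch the same module simultaneously, while the restore in the `finally` block stays outside the lock, as in the hunk above. A bare-bones sketch of that shape (a toy `add_(scale)` patch stands in for the real LoRA math):

    import threading
    from contextlib import contextmanager

    import torch

    class Patcher:
        _thread_lock = threading.Lock()  # shared by every thread patching through this class

        @classmethod
        @contextmanager
        def apply_patch(cls, model: torch.nn.Module, scale: float = 0.1):
            original: dict[str, torch.Tensor] = {}
            try:
                # Only one thread at a time may edit weights in place; others block here.
                with torch.no_grad(), cls._thread_lock:
                    for name, param in model.named_parameters():
                        original[name] = param.detach().clone()
                        param.add_(scale)  # toy stand-in for applying LoRA deltas
                yield model
            finally:
                # Undo the patch on exit, mirroring the structure of the code above.
                with torch.no_grad():
                    for name, param in model.named_parameters():
                        if name in original:
                            param.copy_(original[name])

    # Usage: with Patcher.apply_patch(torch.nn.Linear(4, 4)) as m: m(torch.randn(1, 4))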

View File

@@ -1,4 +1,5 @@
import math
import threading
from dataclasses import dataclass
from typing import List, Optional, Union
@@ -31,9 +32,13 @@ class SDXLConditioningInfo(BasicConditioningInfo):
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
tid = threading.current_thread().ident
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
assert self.pooled_embeds.device == device
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
result = super().to(device=device, dtype=dtype)
assert self.embeds.device == device
return result
@dataclass
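
One footnote on the device assertions added above (an observation, not part of the commit): `torch.device` equality is exact, so an index-less spec such as `"cuda"` does not compare equal to the `cuda:0` a tensor actually reports after being moved; the asserts hold as long as the caller passes a fully-indexed device. For example:

    import torch

    # An index-less CUDA device is not equal to the indexed device a tensor lands on.
    assert torch.device("cuda") != torch.device("cuda:0")

    # Fully-specified devices compare as expected.
    t = torch.zeros(1).to(torch.device("cpu"))
    assert t.device == torch.device("cpu")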

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import math
import threading
from typing import Any, Callable, Optional, Union
import torch
@@ -293,7 +294,10 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None
try:
if conditioning_data.is_sdxl():
#tid = threading.current_thread().ident
#print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
added_cond_kwargs = {
"text_embeds": torch.cat(
[
@@ -311,6 +315,10 @@ class InvokeAIDiffuserComponent:
dim=0,
),
}
except Exception as e:
tid = threading.current_thread().ident
print(f'DEBUG: {tid} {str(e)}')
raise e
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings