fix compel conditioning object caching issue by applying deepcopy() before moving to VRAM

This commit is contained in:
Lincoln Stein 2024-07-18 14:53:03 -04:00
parent 5d6a77d336
commit 02957be333
9 changed files with 77 additions and 125 deletions

View File

@ -1,6 +1,7 @@
from typing import Iterator, List, Optional, Tuple, Union, cast from typing import Iterator, List, Optional, Tuple, Union, cast
import torch import torch
import threading
from compel import Compel, ReturnedEmbeddingsType from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@ -139,6 +140,7 @@ class SDXLPromptInvocationBase:
lora_prefix: str, lora_prefix: str,
zero_on_empty: bool, zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tid = threading.current_thread().ident
tokenizer_info = context.models.load(clip_field.tokenizer) tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder) text_encoder_info = context.models.load(clip_field.text_encoder)
@ -205,6 +207,7 @@ class SDXLPromptInvocationBase:
truncate_long_prompts=False, # TODO: truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled, requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
) )
conjunction = Compel.parse_prompt_string(prompt) conjunction = Compel.parse_prompt_string(prompt)
@ -315,7 +318,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
) )
] ]
) )
conditioning_name = context.conditioning.save(conditioning_data) conditioning_name = context.conditioning.save(conditioning_data)
return ConditioningOutput( return ConditioningOutput(

View File

@ -1,5 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import copy
import inspect import inspect
import threading
from contextlib import ExitStack from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@ -192,10 +194,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Get the text embeddings and masks from the input conditioning fields.""" """Get the text embeddings and masks from the input conditioning fields."""
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = [] text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
text_embeddings_masks: list[Optional[torch.Tensor]] = [] text_embeddings_masks: list[Optional[torch.Tensor]] = []
tid = threading.current_thread().ident
for cond in cond_list: for cond in cond_list:
cond_data = context.conditioning.load(cond.conditioning_name) cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype)) text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
mask = cond.mask mask = cond.mask
if mask is not None: if mask is not None:
mask = context.tensors.load(mask.tensor_name) mask = context.tensors.load(mask.tensor_name)
@ -317,6 +319,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if not isinstance(uncond_list, list): if not isinstance(uncond_list, list):
uncond_list = [uncond_list] uncond_list = [uncond_list]
tid = threading.current_thread().ident
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype cond_list, context, unet.device, unet.dtype
) )

View File

@ -10,6 +10,7 @@ import torch
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from invokeai.backend.util.devices import TorchDevice
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
@ -46,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
def load(self, name: str) -> T: def load(self, name: str) -> T:
file_path = self._get_path(name) file_path = self._get_path(name)
try: try:
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType] return torch.load(file_path, map_location=TorchDevice.choose_torch_device()) # pyright: ignore [reportUnknownMemberType]
except FileNotFoundError as e: except FileNotFoundError as e:
raise ObjectNotFoundError(name) from e raise ObjectNotFoundError(name) from e

View File

@ -53,25 +53,12 @@ class CacheRecord(Generic[T]):
key: Unique key for each model, same as used in the models database. key: Unique key for each model, same as used in the models database.
model: Read-only copy of the model *without weights* residing in the "meta device" model: Read-only copy of the model *without weights* residing in the "meta device"
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model size: Size of the model
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
""" """
key: str key: str
size: int size: int
model: T model: T
state_dict: Optional[Dict[str, torch.Tensor]]
@dataclass @dataclass

View File

@ -159,7 +159,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
device = free_device[0] device = free_device[0]
# we are outside the lock region now # we are outside the lock region now
self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}") self.logger.info(f"{current_thread} Reserved torch device {device}")
# Tell TorchDevice to use this object to get the torch device. # Tell TorchDevice to use this object to get the torch device.
TorchDevice.set_model_cache(self) TorchDevice.set_model_cache(self)
@ -167,7 +167,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
yield device yield device
finally: finally:
with self._device_lock: with self._device_lock:
self.logger.info(f"Released torch device {device}") self.logger.info(f"{current_thread} Released torch device {device}")
self._execution_devices[device] = 0 self._execution_devices[device] = 0
self._free_execution_device.release() self._free_execution_device.release()
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -215,20 +215,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
submodel_type: Optional[SubModelType] = None, submodel_type: Optional[SubModelType] = None,
) -> None: ) -> None:
"""Store model under key and optional submodel_type.""" """Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type) with self._ram_lock:
if key in self._cached_models: key = self._make_cache_key(key, submodel_type)
return if key in self._cached_models:
size = calc_model_size_by_data(model) return
self.make_room(size) size = calc_model_size_by_data(model)
self.make_room(size)
if isinstance(model, torch.nn.Module): tid = threading.current_thread().ident
state_dict = model.state_dict() # keep a master copy of the state dict cache_record = CacheRecord(key=key, model=model, size=size)
model = model.to(device="meta") # and keep a template in the meta device self._cached_models[key] = cache_record
else: self._cache_stack.append(key)
state_dict = None
cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
def get( def get(
self, self,
@ -296,11 +293,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
May raise a torch.cuda.OutOfMemoryError May raise a torch.cuda.OutOfMemoryError
""" """
self.logger.info(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}") with self._ram_lock:
self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
# Some models don't have a state dictionary, in which case the # Some models don't have a state dictionary, in which case the
# stored model will still reside in CPU # stored model will still reside in CPU
if cache_entry.state_dict is None:
if hasattr(cache_entry.model, "to"): if hasattr(cache_entry.model, "to"):
model_in_gpu = copy.deepcopy(cache_entry.model) model_in_gpu = copy.deepcopy(cache_entry.model)
assert hasattr(model_in_gpu, "to") assert hasattr(model_in_gpu, "to")
@ -309,65 +306,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
else: else:
return cache_entry.model # what happens in CPU stays in CPU return cache_entry.model # what happens in CPU stays in CPU
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
assert isinstance(cache_entry.model, torch.nn.Module)
template = cache_entry.model
cls = template.__class__
with skip_torch_weight_init():
if isinstance(cls, ConfigMixin) or hasattr(cls, "from_config"):
working_model = template.__class__.from_config(template.config) # diffusers style
else:
working_model = template.__class__(config=template.config) # transformers style (sigh)
working_model.to(device=target_device, dtype=self._precision)
working_model.load_state_dict(cache_entry.state_dict)
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.info(
f"Moved model '{cache_entry.key}' to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.debug(
f"Moving model '{cache_entry.key}' from to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
return working_model
def print_cuda_stats(self) -> None: def print_cuda_stats(self) -> None:
"""Log CUDA diagnostics.""" """Log CUDA diagnostics."""
vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG) vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@ -445,8 +383,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
raise torch.cuda.OutOfMemoryError raise torch.cuda.OutOfMemoryError
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None: def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
self._cache_stack.remove(cache_entry.key) try:
del self._cached_models[cache_entry.key] self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
except ValueError:
pass
@staticmethod @staticmethod
def _device_name(device: torch.device) -> str: def _device_name(device: torch.device) -> str:

View File

@ -31,10 +31,6 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around.""" """Return the model without moving it around."""
return self._cache_entry.model return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel: def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it.""" """Move the model into the execution device (GPU) and lock it."""
try: try:
@ -56,3 +52,9 @@ class ModelLocker(ModelLockerBase):
def unlock(self) -> None: def unlock(self) -> None:
"""Call upon exit from context.""" """Call upon exit from context."""
self._cache.print_cuda_stats() self._cache.print_cuda_stats()
# This is no longer in use in MGPU.
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return None

View File

@ -2,7 +2,7 @@
"""These classes implement model patching with LoRAs and Textual Inversions.""" """These classes implement model patching with LoRAs and Textual Inversions."""
from __future__ import annotations from __future__ import annotations
import threading
import pickle import pickle
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@ -34,6 +34,9 @@ with LoRAHelper.apply_lora_unet(unet, loras):
# TODO: rename smth like ModelPatcher and add TI method? # TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher: class ModelPatcher:
_thread_lock = threading.Lock()
@staticmethod @staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key assert "." not in lora_key
@ -106,7 +109,10 @@ class ModelPatcher:
""" """
original_weights = {} original_weights = {}
try: try:
with torch.no_grad(): with (
torch.no_grad(),
cls._thread_lock
):
for lora, lora_weight in loras: for lora, lora_weight in loras:
# assert lora.device.type == "cpu" # assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items(): for layer_key, layer in lora.layers.items():
@ -156,9 +162,6 @@ class ModelPatcher:
yield # wait for context manager exit yield # wait for context manager exit
finally: finally:
# LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
# Therefore it should not be necessary to copy the original model weights back.
# This needs to be fixed before resurrecting the VRAM cache.
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad(): with torch.no_grad():
for module_key, weight in original_weights.items(): for module_key, weight in original_weights.items():

View File

@ -1,4 +1,5 @@
import math import math
import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
@ -31,9 +32,13 @@ class SDXLConditioningInfo(BasicConditioningInfo):
add_time_ids: torch.Tensor add_time_ids: torch.Tensor
def to(self, device, dtype=None): def to(self, device, dtype=None):
tid = threading.current_thread().ident
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype) self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
assert self.pooled_embeds.device == device
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype) self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype) result = super().to(device=device, dtype=dtype)
assert self.embeds.device == device
return result
@dataclass @dataclass

View File

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import math import math
import threading
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import torch import torch
@ -293,24 +294,31 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs["regional_ip_data"] = regional_ip_data cross_attention_kwargs["regional_ip_data"] = regional_ip_data
added_cond_kwargs = None added_cond_kwargs = None
if conditioning_data.is_sdxl(): try:
added_cond_kwargs = { if conditioning_data.is_sdxl():
"text_embeds": torch.cat( #tid = threading.current_thread().ident
[ #print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
# TODO: how to pad? just by zeros? or even truncate? added_cond_kwargs = {
conditioning_data.uncond_text.pooled_embeds, "text_embeds": torch.cat(
conditioning_data.cond_text.pooled_embeds, [
], # TODO: how to pad? just by zeros? or even truncate?
dim=0, conditioning_data.uncond_text.pooled_embeds,
), conditioning_data.cond_text.pooled_embeds,
"time_ids": torch.cat( ],
[ dim=0,
conditioning_data.uncond_text.add_time_ids, ),
conditioning_data.cond_text.add_time_ids, "time_ids": torch.cat(
], [
dim=0, conditioning_data.uncond_text.add_time_ids,
), conditioning_data.cond_text.add_time_ids,
} ],
dim=0,
),
}
except Exception as e:
tid = threading.current_thread().ident
print(f'DEBUG: {tid} {str(e)}')
raise e
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings