diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 4a56730e05..f860b21dec 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -1,6 +1,7 @@
 from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
+import threading
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -139,6 +140,7 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        tid = threading.current_thread().ident
         tokenizer_info = context.models.load(clip_field.tokenizer)
         text_encoder_info = context.models.load(clip_field.text_encoder)
 
@@ -205,6 +207,7 @@ class SDXLPromptInvocationBase:
                 truncate_long_prompts=False,  # TODO:
                 returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
                 requires_pooled=get_pooled,
+                device=TorchDevice.choose_torch_device(),
             )
 
             conjunction = Compel.parse_prompt_string(prompt)
@@ -315,7 +318,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 )
             ]
         )
-
         conditioning_name = context.conditioning.save(conditioning_data)
 
         return ConditioningOutput(
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 13a92efab8..a58fb69cb3 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -1,5 +1,7 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import copy
 import inspect
+import threading
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
@@ -192,10 +194,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         """Get the text embeddings and masks from the input conditioning fields."""
         text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
         text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        tid = threading.current_thread().ident
         for cond in cond_list:
-            cond_data = context.conditioning.load(cond.conditioning_name)
+            cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
             text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
-
             mask = cond.mask
             if mask is not None:
                 mask = context.tensors.load(mask.tensor_name)
@@ -317,6 +319,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if not isinstance(uncond_list, list):
            uncond_list = [uncond_list]
 
+        tid = threading.current_thread().ident
         cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
             cond_list, context, unet.device, unet.dtype
         )
diff --git a/invokeai/app/services/object_serializer/object_serializer_disk.py b/invokeai/app/services/object_serializer/object_serializer_disk.py
index d3171f8530..e27ea46e6b 100644
--- a/invokeai/app/services/object_serializer/object_serializer_disk.py
+++ b/invokeai/app/services/object_serializer/object_serializer_disk.py
@@ -10,6 +10,7 @@ import torch
 from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
 from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.util.devices import TorchDevice
 
 if TYPE_CHECKING:
     from invokeai.app.services.invoker import Invoker
@@ -46,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
     def load(self, name: str) -> T:
         file_path = self._get_path(name)
         try:
-            return torch.load(file_path)  # pyright: ignore [reportUnknownMemberType]
+            return torch.load(file_path, map_location=TorchDevice.choose_torch_device())  # pyright: ignore [reportUnknownMemberType]
         except FileNotFoundError as e:
             raise ObjectNotFoundError(name) from e
 
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
index 6e518bc9e3..4fe99c31e6 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py
@@ -53,25 +53,12 @@ class CacheRecord(Generic[T]):
 
     key: Unique key for each model, same as used in the models database.
    model: Read-only copy of the model *without weights* residing in the "meta device"
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
     size: Size of the model
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
     """
 
     key: str
     size: int
     model: T
-    state_dict: Optional[Dict[str, torch.Tensor]]
 
 
 @dataclass
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 3a2d36e87a..817fcb2ec0 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -159,7 +159,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 device = free_device[0]
 
         # we are outside the lock region now
-        self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
+        self.logger.info(f"{current_thread} Reserved torch device {device}")
 
         # Tell TorchDevice to use this object to get the torch device.
         TorchDevice.set_model_cache(self)
@@ -167,7 +167,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             yield device
         finally:
             with self._device_lock:
-                self.logger.info(f"Released torch device {device}")
+                self.logger.info(f"{current_thread} Released torch device {device}")
                 self._execution_devices[device] = 0
                 self._free_execution_device.release()
                 torch.cuda.empty_cache()
@@ -215,20 +215,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            return
-        size = calc_model_size_by_data(model)
-        self.make_room(size)
+        with self._ram_lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                return
+            size = calc_model_size_by_data(model)
+            self.make_room(size)
 
-        if isinstance(model, torch.nn.Module):
-            state_dict = model.state_dict()  # keep a master copy of the state dict
-            model = model.to(device="meta")  # and keep a template in the meta device
-        else:
-            state_dict = None
-        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+            tid = threading.current_thread().ident
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)
 
     def get(
         self,
@@ -296,11 +293,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.info(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
+        with self._ram_lock:
+            self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
 
-        # Some models don't have a state dictionary, in which case the
-        # stored model will still reside in CPU
-        if cache_entry.state_dict is None:
+            # Some models don't have a state dictionary, in which case the
+            # stored model will still reside in CPU
             if hasattr(cache_entry.model, "to"):
                 model_in_gpu = copy.deepcopy(cache_entry.model)
                 assert hasattr(model_in_gpu, "to")
@@ -309,65 +306,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 return cache_entry.model  # what happens in CPU stays in CPU
 
-        # This roundabout method for moving the model around is done to avoid
-        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
-        # When moving to VRAM, we copy (not move) each element of the state dict from
-        # RAM to a new state dict in VRAM, and then inject it into the model.
-        # This operation is slightly faster than running `to()` on the whole model.
-        #
-        # When the model needs to be removed from VRAM we simply delete the copy
-        # of the state dict in VRAM, and reinject the state dict that is cached
-        # in RAM into the model. So this operation is very fast.
-        start_model_to_time = time.time()
-        snapshot_before = self._capture_memory_snapshot()
-        try:
-            assert isinstance(cache_entry.model, torch.nn.Module)
-            template = cache_entry.model
-            cls = template.__class__
-            with skip_torch_weight_init():
-                if isinstance(cls, ConfigMixin) or hasattr(cls, "from_config"):
-                    working_model = template.__class__.from_config(template.config)  # diffusers style
-                else:
-                    working_model = template.__class__(config=template.config)  # transformers style (sigh)
-            working_model.to(device=target_device, dtype=self._precision)
-            working_model.load_state_dict(cache_entry.state_dict)
-        except Exception as e:  # blow away cache entry
-            self._delete_cache_entry(cache_entry)
-            raise e
-
-        snapshot_after = self._capture_memory_snapshot()
-        end_model_to_time = time.time()
-        self.logger.info(
-            f"Moved model '{cache_entry.key}' to"
-            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
-            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-        )
-
-        if (
-            snapshot_before is not None
-            and snapshot_after is not None
-            and snapshot_before.vram is not None
-            and snapshot_after.vram is not None
-        ):
-            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-
-            # If the estimated model size does not match the change in VRAM, log a warning.
-            if not math.isclose(
-                vram_change,
-                cache_entry.size,
-                rel_tol=0.1,
-                abs_tol=10 * MB,
-            ):
-                self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from to"
-                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
-                    " estimated size may be incorrect. Estimated model size:"
-                    f" {(cache_entry.size/GIG):.3f} GB.\n"
-                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-                )
-        return working_model
-
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
@@ -445,8 +383,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
             raise torch.cuda.OutOfMemoryError
 
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
-        self._cache_stack.remove(cache_entry.key)
-        del self._cached_models[cache_entry.key]
+        try:
+            self._cache_stack.remove(cache_entry.key)
+            del self._cached_models[cache_entry.key]
+        except ValueError:
+            pass
 
     @staticmethod
     def _device_name(device: torch.device) -> str:
diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py
index 815fd41f04..fd85e2d8ad 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_locker.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py
@@ -31,10 +31,6 @@ class ModelLocker(ModelLockerBase):
         """Return the model without moving it around."""
         return self._cache_entry.model
 
-    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
-        """Return the state dict (if any) for the cached model."""
-        return self._cache_entry.state_dict
-
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
         try:
@@ -56,3 +52,9 @@ class ModelLocker(ModelLockerBase):
     def unlock(self) -> None:
         """Call upon exit from context."""
         self._cache.print_cuda_stats()
+
+    # This is no longer in use in MGPU.
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        return None
+
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 0f57c0efdc..b45cf91c98 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -2,7 +2,7 @@
 """These classes implement model patching with LoRAs and Textual Inversions."""
 from __future__ import annotations
-
+import threading
 import pickle
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -34,6 +34,9 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 
 # TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+
+    _thread_lock = threading.Lock()
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
@@ -106,7 +109,10 @@ class ModelPatcher:
         """
         original_weights = {}
         try:
-            with torch.no_grad():
+            with (
+                torch.no_grad(),
+                cls._thread_lock
+            ):
                 for lora, lora_weight in loras:
                     # assert lora.device.type == "cpu"
                     for layer_key, layer in lora.layers.items():
@@ -156,9 +162,6 @@ class ModelPatcher:
             yield  # wait for context manager exit
 
         finally:
-            # LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
-            # Therefore it should not be necessary to copy the original model weights back.
-            # This needs to be fixed before resurrecting the VRAM cache.
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
index 85950a01df..b0291d06fe 100644
--- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
+++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -1,4 +1,5 @@
 import math
+import threading
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -31,9 +32,13 @@ class SDXLConditioningInfo(BasicConditioningInfo):
     add_time_ids: torch.Tensor
 
     def to(self, device, dtype=None):
+        tid = threading.current_thread().ident
         self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        assert self.pooled_embeds.device == device
         self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
-        return super().to(device=device, dtype=dtype)
+        result = super().to(device=device, dtype=dtype)
+        assert self.embeds.device == device
+        return result
 
 
 @dataclass
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index f418133e49..3e3040968d 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import math
+import threading
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -293,24 +294,31 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_ip_data"] = regional_ip_data
 
         added_cond_kwargs = None
-        if conditioning_data.is_sdxl():
-            added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [
-                        # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.uncond_text.pooled_embeds,
-                        conditioning_data.cond_text.pooled_embeds,
-                    ],
-                    dim=0,
-                ),
-                "time_ids": torch.cat(
-                    [
-                        conditioning_data.uncond_text.add_time_ids,
-                        conditioning_data.cond_text.add_time_ids,
-                    ],
-                    dim=0,
-                ),
-            }
+        try:
+            if conditioning_data.is_sdxl():
+                #tid = threading.current_thread().ident
+                #print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
+                added_cond_kwargs = {
+                    "text_embeds": torch.cat(
+                        [
+                            # TODO: how to pad? just by zeros? or even truncate?
+                            conditioning_data.uncond_text.pooled_embeds,
+                            conditioning_data.cond_text.pooled_embeds,
+                        ],
+                        dim=0,
+                    ),
+                    "time_ids": torch.cat(
+                        [
+                            conditioning_data.uncond_text.add_time_ids,
+                            conditioning_data.cond_text.add_time_ids,
+                        ],
+                        dim=0,
+                    ),
+                }
+        except Exception as e:
+            tid = threading.current_thread().ident
+            print(f'DEBUG: {tid} {str(e)}')
+            raise e
 
         if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
             # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings