Mirror of https://github.com/invoke-ai/InvokeAI

Commit 02957be333 (parent: 5d6a77d336)

    fix compel conditioning object caching issue by applying deepcopy() before moving to VRAM
@@ -1,6 +1,7 @@
 from typing import Iterator, List, Optional, Tuple, Union, cast
 
 import torch
+import threading
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -139,6 +140,7 @@ class SDXLPromptInvocationBase:
         lora_prefix: str,
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        tid = threading.current_thread().ident
         tokenizer_info = context.models.load(clip_field.tokenizer)
         text_encoder_info = context.models.load(clip_field.text_encoder)
 
@@ -205,6 +207,7 @@ class SDXLPromptInvocationBase:
                 truncate_long_prompts=False,  # TODO:
                 returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
                 requires_pooled=get_pooled,
+                device=TorchDevice.choose_torch_device(),
             )
 
             conjunction = Compel.parse_prompt_string(prompt)
@@ -315,7 +318,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 )
             ]
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
 
         return ConditioningOutput(
@@ -1,5 +1,7 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import copy
 import inspect
+import threading
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
 
@@ -192,10 +194,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         """Get the text embeddings and masks from the input conditioning fields."""
         text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
         text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        tid = threading.current_thread().ident
         for cond in cond_list:
-            cond_data = context.conditioning.load(cond.conditioning_name)
+            cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
             text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
 
             mask = cond.mask
             if mask is not None:
                 mask = context.tensors.load(mask.tensor_name)
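The deep copy added in the hunk above is the heart of the commit: `context.conditioning.load()` can return an object shared with the in-RAM conditioning cache, and the conditioning containers mutate themselves in place when `.to()` is called (see the SDXLConditioningInfo hunk further down). A minimal hedged sketch of the hazard and the remedy, using an invented `CachedConditioning` class and `cache` dict rather than InvokeAI's actual types:

import copy
from dataclasses import dataclass

import torch


@dataclass
class CachedConditioning:
    # Stand-in for a cached conditioning object; like SDXLConditioningInfo.to()
    # in the diff, .to() mutates the object in place.
    embeds: torch.Tensor

    def to(self, device: torch.device) -> "CachedConditioning":
        self.embeds = self.embeds.to(device)
        return self


cache = {"prompt-1": CachedConditioning(torch.zeros(2, 77, 768))}

# Hazard: calling .to() on the cached object itself would re-home the shared
# tensors onto this thread's device, e.g.:
#   cache["prompt-1"].to(torch.device("cuda:0"))
# and every other thread would then see embeddings pinned to that device.

# Remedy (what the commit does): deep-copy first, then move only the private copy.
private = copy.deepcopy(cache["prompt-1"]).to(torch.device("cpu"))
assert private is not cache["prompt-1"]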
@@ -317,6 +319,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         if not isinstance(uncond_list, list):
             uncond_list = [uncond_list]
 
+        tid = threading.current_thread().ident
         cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
             cond_list, context, unet.device, unet.dtype
         )
@@ -10,6 +10,7 @@ import torch
 from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
 from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.util.devices import TorchDevice
 
 if TYPE_CHECKING:
     from invokeai.app.services.invoker import Invoker
@@ -46,7 +47,7 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
     def load(self, name: str) -> T:
         file_path = self._get_path(name)
         try:
-            return torch.load(file_path)  # pyright: ignore [reportUnknownMemberType]
+            return torch.load(file_path, map_location=TorchDevice.choose_torch_device())  # pyright: ignore [reportUnknownMemberType]
         except FileNotFoundError as e:
             raise ObjectNotFoundError(name) from e
 
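The `map_location` argument added above matters because `torch.load` otherwise restores each tensor to the device it was pickled from, which is wrong when a different thread (or a different GPU) is doing the loading. A hedged sketch of the same idea with plain `torch` calls; `choose_device`, `save_obj`, and `load_obj` are illustrative helpers, not InvokeAI's API:

from pathlib import Path

import torch


def choose_device() -> torch.device:
    # Illustrative stand-in for a per-thread device chooser.
    return torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def save_obj(obj: object, path: Path) -> None:
    torch.save(obj, path)


def load_obj(path: Path) -> object:
    # Without map_location, tensors saved from cuda:1 would come back on cuda:1,
    # regardless of which device this thread is supposed to be using.
    return torch.load(path, map_location=choose_device())


if __name__ == "__main__":
    p = Path("conditioning.pt")
    save_obj({"embeds": torch.zeros(2, 77, 768)}, p)
    restored = load_obj(p)
    print(restored["embeds"].device)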
@@ -53,25 +53,12 @@ class CacheRecord(Generic[T]):
 
     key: Unique key for each model, same as used in the models database.
     model: Read-only copy of the model *without weights* residing in the "meta device"
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
     size: Size of the model
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
     """
 
     key: str
     size: int
     model: T
-    state_dict: Optional[Dict[str, torch.Tensor]]
 
 
 @dataclass
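For orientation, the record now carries only the three fields that survive the hunk. A hedged sketch of the resulting dataclass, reconstructed solely from the visible lines (the `Generic[T]` base comes from the hunk header; the docstring below is paraphrased filler):

from dataclasses import dataclass
from typing import Generic, TypeVar

T = TypeVar("T")


@dataclass
class CacheRecord(Generic[T]):
    """A cached model entry: its database key, its size, and the model object itself.

    (Paraphrased; the state_dict field and the VRAM-template notes were removed
    by this commit.)
    """

    key: str
    size: int
    model: T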
@@ -159,7 +159,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             device = free_device[0]
 
         # we are outside the lock region now
-        self.logger.info(f"Reserved torch device {device} for execution thread {current_thread}")
+        self.logger.info(f"{current_thread} Reserved torch device {device}")
 
         # Tell TorchDevice to use this object to get the torch device.
         TorchDevice.set_model_cache(self)
@@ -167,7 +167,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             yield device
         finally:
             with self._device_lock:
-                self.logger.info(f"Released torch device {device}")
+                self.logger.info(f"{current_thread} Released torch device {device}")
                 self._execution_devices[device] = 0
                 self._free_execution_device.release()
                 torch.cuda.empty_cache()
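The two logging tweaks above sit inside the multi-GPU device-reservation machinery hinted at by `self._device_lock`, `self._execution_devices`, and `self._free_execution_device`. A hedged sketch of how such a reservation scheme can work; `DevicePool` and `reserve` are invented names, not InvokeAI's actual implementation:

import threading
from contextlib import contextmanager
from typing import Dict, Iterator, List

import torch


class DevicePool:
    """Hand out one torch device per worker thread, blocking when all are busy."""

    def __init__(self, devices: List[torch.device]) -> None:
        self._devices: Dict[torch.device, int] = {d: 0 for d in devices}  # 0 = free, thread id = busy
        self._lock = threading.Lock()
        self._free = threading.Semaphore(len(devices))

    @contextmanager
    def reserve(self) -> Iterator[torch.device]:
        self._free.acquire()  # block until some device is free
        with self._lock:
            device = next(d for d, owner in self._devices.items() if owner == 0)
            self._devices[device] = threading.current_thread().ident or -1
        try:
            yield device
        finally:
            with self._lock:
                self._devices[device] = 0
            self._free.release()


# Usage: each denoising thread reserves a device for the duration of its work.
pool = DevicePool([torch.device("cpu")])
with pool.reserve() as dev:
    print(f"{threading.current_thread().name} running on {dev}")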
@@ -215,20 +215,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            return
-        size = calc_model_size_by_data(model)
-        self.make_room(size)
-
-        if isinstance(model, torch.nn.Module):
-            state_dict = model.state_dict()  # keep a master copy of the state dict
-            model = model.to(device="meta")  # and keep a template in the meta device
-        else:
-            state_dict = None
-        cache_record = CacheRecord(key=key, model=model, state_dict=state_dict, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+        with self._ram_lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                return
+            size = calc_model_size_by_data(model)
+            self.make_room(size)
+
+            tid = threading.current_thread().ident
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)
 
     def get(
         self,
@@ -296,11 +293,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         May raise a torch.cuda.OutOfMemoryError
         """
-        self.logger.info(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
+        with self._ram_lock:
+            self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")
 
-        # Some models don't have a state dictionary, in which case the
-        # stored model will still reside in CPU
-        if cache_entry.state_dict is None:
-            if hasattr(cache_entry.model, "to"):
-                model_in_gpu = copy.deepcopy(cache_entry.model)
-                assert hasattr(model_in_gpu, "to")
+            # Some models don't have a state dictionary, in which case the
+            # stored model will still reside in CPU
+            if hasattr(cache_entry.model, "to"):
+                model_in_gpu = copy.deepcopy(cache_entry.model)
+                assert hasattr(model_in_gpu, "to")
@@ -309,65 +306,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
             else:
                 return cache_entry.model  # what happens in CPU stays in CPU
 
-        # This roundabout method for moving the model around is done to avoid
-        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
-        # When moving to VRAM, we copy (not move) each element of the state dict from
-        # RAM to a new state dict in VRAM, and then inject it into the model.
-        # This operation is slightly faster than running `to()` on the whole model.
-        #
-        # When the model needs to be removed from VRAM we simply delete the copy
-        # of the state dict in VRAM, and reinject the state dict that is cached
-        # in RAM into the model. So this operation is very fast.
-        start_model_to_time = time.time()
-        snapshot_before = self._capture_memory_snapshot()
-        try:
-            assert isinstance(cache_entry.model, torch.nn.Module)
-            template = cache_entry.model
-            cls = template.__class__
-            with skip_torch_weight_init():
-                if isinstance(cls, ConfigMixin) or hasattr(cls, "from_config"):
-                    working_model = template.__class__.from_config(template.config)  # diffusers style
-                else:
-                    working_model = template.__class__(config=template.config)  # transformers style (sigh)
-            working_model.to(device=target_device, dtype=self._precision)
-            working_model.load_state_dict(cache_entry.state_dict)
-        except Exception as e:  # blow away cache entry
-            self._delete_cache_entry(cache_entry)
-            raise e
-
-        snapshot_after = self._capture_memory_snapshot()
-        end_model_to_time = time.time()
-        self.logger.info(
-            f"Moved model '{cache_entry.key}' to"
-            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-            f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
-            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-        )
-
-        if (
-            snapshot_before is not None
-            and snapshot_after is not None
-            and snapshot_before.vram is not None
-            and snapshot_after.vram is not None
-        ):
-            vram_change = abs(snapshot_before.vram - snapshot_after.vram)
-
-            # If the estimated model size does not match the change in VRAM, log a warning.
-            if not math.isclose(
-                vram_change,
-                cache_entry.size,
-                rel_tol=0.1,
-                abs_tol=10 * MB,
-            ):
-                self.logger.debug(
-                    f"Moving model '{cache_entry.key}' from to"
-                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
-                    " estimated size may be incorrect. Estimated model size:"
-                    f" {(cache_entry.size/GIG):.3f} GB.\n"
-                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
-                )
-        return working_model
-
     def print_cuda_stats(self) -> None:
         """Log CUDA diagnostics."""
         vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
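The block deleted above was the old state-dict-template path: keep the weights in RAM, copy them into a freshly constructed model on the GPU, then swap them back out afterwards. What survives (visible in the context lines) is the simpler pattern of deep-copying the cached model and calling `.to()` on the copy. A hedged sketch of that simpler pattern, with an invented `move_to_device` helper rather than the cache's actual method:

import copy

import torch


def move_to_device(cached_model: torch.nn.Module, target_device: torch.device) -> torch.nn.Module:
    """Return a private copy of a RAM-cached model on target_device.

    The cached instance is never mutated, so other threads can still deep-copy
    it onto their own devices. This trades extra copy time for simplicity
    compared with the removed state-dict-template approach.
    """
    model_copy = copy.deepcopy(cached_model)
    return model_copy.to(device=target_device)


# Usage sketch: the cache keeps the master copy on the CPU.
master = torch.nn.Linear(4, 4)
working = move_to_device(master, torch.device("cpu"))
assert working is not master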
@@ -445,8 +383,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
             raise torch.cuda.OutOfMemoryError
 
     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
-        self._cache_stack.remove(cache_entry.key)
-        del self._cached_models[cache_entry.key]
+        try:
+            self._cache_stack.remove(cache_entry.key)
+            del self._cached_models[cache_entry.key]
+        except ValueError:
+            pass
 
     @staticmethod
     def _device_name(device: torch.device) -> str:
@@ -31,10 +31,6 @@ class ModelLocker(ModelLockerBase):
         """Return the model without moving it around."""
         return self._cache_entry.model
 
-    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
-        """Return the state dict (if any) for the cached model."""
-        return self._cache_entry.state_dict
-
     def lock(self) -> AnyModel:
         """Move the model into the execution device (GPU) and lock it."""
         try:
@@ -56,3 +52,9 @@ class ModelLocker(ModelLockerBase):
     def unlock(self) -> None:
         """Call upon exit from context."""
         self._cache.print_cuda_stats()
+
+    # This is no longer in use in MGPU.
+    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
+        """Return the state dict (if any) for the cached model."""
+        return None
+
@@ -2,7 +2,7 @@
 """These classes implement model patching with LoRAs and Textual Inversions."""
 
 from __future__ import annotations
+import threading
 import pickle
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
@@ -34,6 +34,9 @@ with LoRAHelper.apply_lora_unet(unet, loras):
 
 # TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
+
+    _thread_lock = threading.Lock()
+
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
         assert "." not in lora_key
@@ -106,7 +109,10 @@ class ModelPatcher:
         """
         original_weights = {}
         try:
-            with torch.no_grad():
+            with (
+                torch.no_grad(),
+                cls._thread_lock
+            ):
                 for lora, lora_weight in loras:
                     # assert lora.device.type == "cpu"
                     for layer_key, layer in lora.layers.items():
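The combined context manager above serializes weight patching: any thread applying LoRA layers must hold the class-level `_thread_lock` while it mutates the model's weights. A hedged sketch of the pattern with a hypothetical `Patcher` class and a random-perturbation stand-in for the LoRA deltas (the diff uses the parenthesized multi-line `with` form, which needs Python 3.10+; the single-line form below is equivalent):

import threading
from contextlib import contextmanager
from typing import Iterator

import torch


class Patcher:
    # One lock shared by all threads so only one of them rewrites weights at a time.
    _thread_lock = threading.Lock()

    @classmethod
    @contextmanager
    def apply_patch(cls, model: torch.nn.Module, scale: float) -> Iterator[None]:
        original = {name: p.detach().clone() for name, p in model.named_parameters()}
        try:
            with torch.no_grad(), cls._thread_lock:
                for p in model.parameters():
                    p.add_(torch.randn_like(p) * scale)  # stand-in for adding LoRA deltas
            yield
        finally:
            with torch.no_grad():
                for name, p in model.named_parameters():
                    p.copy_(original[name])  # restore the original weights on exit


# Usage sketch
net = torch.nn.Linear(2, 2)
with Patcher.apply_patch(net, scale=0.01):
    pass  # run inference with patched weights here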
@@ -156,9 +162,6 @@ class ModelPatcher:
             yield  # wait for context manager exit
 
         finally:
-            # LS check: for now, we are not reusing models in VRAM but re-copying them each time they are needed.
-            # Therefore it should not be necessary to copy the original model weights back.
-            # This needs to be fixed before resurrecting the VRAM cache.
             assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
                 for module_key, weight in original_weights.items():
@@ -1,4 +1,5 @@
 import math
+import threading
 from dataclasses import dataclass
 from typing import List, Optional, Union
 
@@ -31,9 +32,13 @@ class SDXLConditioningInfo(BasicConditioningInfo):
     add_time_ids: torch.Tensor
 
     def to(self, device, dtype=None):
+        tid = threading.current_thread().ident
         self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        assert self.pooled_embeds.device == device
         self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
-        return super().to(device=device, dtype=dtype)
+        result = super().to(device=device, dtype=dtype)
+        assert self.embeds.device == device
+        return result
 
 
 @dataclass
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import math
+import threading
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -293,24 +294,31 @@ class InvokeAIDiffuserComponent:
             cross_attention_kwargs["regional_ip_data"] = regional_ip_data
 
         added_cond_kwargs = None
-        if conditioning_data.is_sdxl():
-            added_cond_kwargs = {
-                "text_embeds": torch.cat(
-                    [
-                        # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.uncond_text.pooled_embeds,
-                        conditioning_data.cond_text.pooled_embeds,
-                    ],
-                    dim=0,
-                ),
-                "time_ids": torch.cat(
-                    [
-                        conditioning_data.uncond_text.add_time_ids,
-                        conditioning_data.cond_text.add_time_ids,
-                    ],
-                    dim=0,
-                ),
-            }
+        try:
+            if conditioning_data.is_sdxl():
+                #tid = threading.current_thread().ident
+                #print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
+                added_cond_kwargs = {
+                    "text_embeds": torch.cat(
+                        [
+                            # TODO: how to pad? just by zeros? or even truncate?
+                            conditioning_data.uncond_text.pooled_embeds,
+                            conditioning_data.cond_text.pooled_embeds,
+                        ],
+                        dim=0,
+                    ),
+                    "time_ids": torch.cat(
+                        [
+                            conditioning_data.uncond_text.add_time_ids,
+                            conditioning_data.cond_text.add_time_ids,
+                        ],
+                        dim=0,
+                    ),
+                }
+        except Exception as e:
+            tid = threading.current_thread().ident
+            print(f'DEBUG: {tid} {str(e)}')
+            raise e
 
         if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
             # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings