mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fixup config_default; patch TorchDevice to work dynamically
This commit is contained in:
parent 7dd93cb810
commit f7436f3bae
@@ -5,7 +5,6 @@ from logging import Logger
 import torch

 import invokeai.backend.util.devices  # horrible hack
-
 from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
 from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
 from invokeai.app.services.shared.sqlite.sqlite_util import init_db
@@ -68,16 +68,19 @@ class CompelInvocation(BaseInvocation):
         tokenizer_model = tokenizer_info.model
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(self.clip.text_encoder)
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)
+
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
                 lora_info = context.models.load(lora.lora)
                 assert isinstance(lora_info.model, LoRAModelRaw)
-                with lora_info as model:
-                    yield (model, lora.weight)
+                yield (lora_info.model, lora.weight)
                 del lora_info
             return

+        # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

         ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)

         with (
@@ -136,7 +139,8 @@ class SDXLPromptInvocationBase:
         tokenizer_model = tokenizer_info.model
         assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(clip_field.text_encoder)
-        assert isinstance(text_encoder_info.model, (CLIPTextModel, CLIPTextModelWithProjection))
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))

         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -4,7 +4,7 @@ import math
 from contextlib import ExitStack
 from functools import singledispatchmethod
 from typing import Any, Iterator, List, Literal, Optional, Tuple, Union
-import threading
 import einops
 import numpy as np
 import numpy.typing as npt
@@ -525,11 +525,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
         )
-
-        if conditioning_data.unconditioned_embeddings.embeds.device != conditioning_data.text_embeddings.embeds.device:
-            print(f'DEBUG; ERROR uc={conditioning_data.unconditioned_embeddings.embeds.device} c={conditioning_data.text_embeddings.embeds.device} unet={unet.device}, tid={threading.current_thread().ident}')
-
-
         return conditioning_data

     def create_pipeline(
@@ -899,6 +894,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 mask = mask.to(device=unet.device, dtype=unet.dtype)
             if masked_latents is not None:
                 masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+
             scheduler = get_scheduler(
                 context=context,
                 scheduler_info=self.unet.scheduler,
@@ -31,7 +31,7 @@ ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.1"
+CONFIG_SCHEMA_VERSION = "4.0.2"


 def get_default_ram_cache_size() -> float:
@@ -101,9 +101,9 @@ class InvokeAIAppConfig(BaseSettings):
         ram: Maximum memory amount used by memory model cache for rapid switching (GB).
         convert_cache: Maximum size of on-disk converted models cache (GB).
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
-        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `cuda:8`, `mps`
-        devices: List of execution devices to use in a multi-GPU environment; will override default device selected.
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
+        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
+        devices: List of execution devices; will override default device selected.
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -366,9 +366,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
             # `max_cache_size` was renamed to `ram` some time in v3, but both names were used
             if k == "max_cache_size" and "ram" not in category_dict:
                 parsed_config_dict["ram"] = v
-            # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
-            if k == "max_vram_cache_size" and "vram" not in category_dict:
-                parsed_config_dict["vram"] = v
+            # vram was removed in v4.0.2
+            if k in ["vram", "max_vram_cache_size", "lazy_offload"]:
+                continue
             # autocast was removed in v4.0.1
             if k == "precision" and v == "autocast":
                 parsed_config_dict["precision"] = "auto"
@@ -416,6 +416,25 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
     return config


+def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.1 config dictionary to a current config object.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.1 config file.
+
+    Returns:
+        An instance of `InvokeAIAppConfig` with the migrated settings.
+    """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        if k not in ["vram", "lazy_offload"]:
+            parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
+
+
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.

@@ -447,6 +466,10 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
             loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
             loaded_config_dict.write_file(config_path)
+        elif loaded_config_dict["schema_version"] == "4.0.1":
+            loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
+            loaded_config_dict.write_file(config_path)

         # Attempt to load as a v4 config file
         try:
             # Meta is not included in the model fields, so we need to validate it separately
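For reference, here is a minimal self-contained sketch of what the new v4.0.1 to v4.0.2 migration does to a settings dictionary. It only mirrors the key-filtering shown above; the real `migrate_v4_0_1_config_dict` validates the result with `DefaultInvokeAIAppConfig.model_validate` rather than returning a plain dict, and the sample values below are hypothetical.

from typing import Any

CONFIG_SCHEMA_VERSION = "4.0.2"  # matches the bump in this commit


def migrate_v4_0_1_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
    """Drop the settings removed in v4.0.2 and bump the schema version."""
    migrated: dict[str, Any] = {}
    for k, v in config_dict.items():
        if k not in ["vram", "lazy_offload"]:  # dropped in v4.0.2
            migrated[k] = v
        if k == "schema_version":
            migrated[k] = CONFIG_SCHEMA_VERSION
    return migrated


old = {"schema_version": "4.0.1", "ram": 7.5, "vram": 0.5, "lazy_offload": True, "precision": "auto"}
print(migrate_v4_0_1_dict(old))
# {'schema_version': '4.0.2', 'ram': 7.5, 'precision': 'auto'}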
@@ -1,14 +1,11 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

-from typing import Optional
-
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig
@@ -89,8 +86,6 @@ class ModelManagerService(ModelManagerServiceBase):
             max_cache_size=app_config.ram,
             logger=logger,
             execution_devices=execution_devices,
-            max_vram_cache_size=app_config.vram,
-            lazy_offloading=app_config.lazy_offload,
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
@@ -1,6 +1,6 @@
+import threading
 from queue import Queue
 from typing import TYPE_CHECKING, Optional, TypeVar
-import threading

 from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase

@@ -187,8 +187,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                        profiler.start(profile_id=session.session_id)

                    # reserve a GPU for this session - may block
-                    with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device() as gpu:
-
+                    with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
                        # Prepare invocations and take the first
                        with self._process_lock:
                            invocation = session.session.next()
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional, Union

+import torch
 from PIL.Image import Image
 from torch import Tensor

@@ -15,15 +16,24 @@ from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
 from invokeai.backend.model_manager.load.load_base import LoadedModel
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
+from invokeai.backend.util.devices import TorchDevice

 if TYPE_CHECKING:
     from invokeai.app.invocations.baseinvocation import BaseInvocation
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
+    from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

 """
 The InvocationContext provides access to various services and data about the current invocation.
@@ -473,6 +483,28 @@ class UtilInterface(InvocationContextInterface):
             is_canceled=self.is_canceled,
         )

+    def torch_device(self) -> torch.device:
+        """
+        Return a torch device to use in the current invocation.
+
+        Returns:
+            A torch.device not currently in use by the system.
+        """
+        ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
+        return ram_cache.get_execution_device()
+
+    def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
+        """
+        Return a precision type to use with the current invocation and torch device.
+
+        Args:
+            device: Optional device.
+
+        Returns:
+            A torch.dtype suited for the current device.
+        """
+        return TorchDevice.choose_torch_dtype(device)
+

 class InvocationContext:
     """Provides access to various services and data for the current invocation.
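With these helpers on `UtilInterface`, a node can ask the invocation context for the GPU reserved for its session instead of calling `TorchDevice` directly. Below is a sketch of how a custom node might use them, assuming the v4 node API; the node, its name, and the `LatentsOutput` wiring are illustrative, and only `context.util.torch_device()` and `context.util.torch_dtype()` come from this change.

import torch

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext


@invocation("device_aware_noise", title="Device-aware noise (example)", version="1.0.0")
class DeviceAwareNoiseInvocation(BaseInvocation):
    """Hypothetical node: fills a latent tensor on the reserved execution device."""

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        device = context.util.torch_device()      # GPU reserved for this session's worker thread
        dtype = context.util.torch_dtype(device)  # precision suited to that device
        latents = torch.randn(1, 4, 64, 64, device=device, dtype=dtype)
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)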
@@ -106,7 +106,7 @@ class ModelCacheBase(ABC, Generic[T]):
        Return an execution device that has been reserved for current thread.

        Note that reservations are done using the current thread's TID.
-       It would be better to do this using the session ID, but that involves
+       It might be better to do this using the session ID, but that involves
        too many detailed changes to model manager calls.

        May generate a ValueError if no GPU has been reserved.
@@ -127,7 +127,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
         assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
         if not assigned:
             raise ValueError("No GPU has been reserved for the use of thread {current_thread}")
-        print(f'DEBUG: TID={current_thread}; owns {assigned[0]}')
         return assigned[0]

     @contextmanager
@@ -157,12 +156,15 @@ class ModelCache(ModelCacheBase[AnyModel]):
             device = free_device[0]

         # we are outside the lock region now
-        print(f'DEBUG: RESERVED {device} for TID {current_thread}')
+        self.logger.info("Reserved torch device {device} for execution thread {current_thread}")
+
+        # Tell TorchDevice to use this object to get the torch device.
+        TorchDevice.set_model_cache(self)
         try:
             yield device
         finally:
             with self._device_lock:
-                print(f'DEBUG: RELEASED {device} for TID {current_thread}')
+                self.logger.info("Released torch device {device}")
                 self._execution_devices[device] = 0
                 self._free_execution_device.release()
                 torch.cuda.empty_cache()
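This hunk is the core of the change: reserving a device now also registers the cache with TorchDevice, so code running on the reserved thread can resolve its device dynamically instead of having it threaded through call signatures. A rough sketch of the intended flow, assuming `cache` is the ModelCache owned by the model manager and `run_session` is a hypothetical worker:

def run_session(cache) -> None:
    # Blocks until a free GPU is available, then records this thread's TID as its owner.
    with cache.reserve_execution_device() as device:
        # Because reserve_execution_device() now calls TorchDevice.set_model_cache(cache),
        # any code on this thread can recover the same device without it being passed in:
        assert cache.get_execution_device() == device
        ...  # run the session's invocations here
    # On exit the device is released, the free-device semaphore is bumped,
    # and torch.cuda.empty_cache() is called.


# Typically one worker thread per session, each holding its own reservation, e.g.:
# threading.Thread(target=run_session, args=(model_manager.load.ram_cache,)).start()

One small observation on the hunk above: the two new `self.logger.info(...)` calls use plain strings rather than f-strings, so `{device}` and `{current_thread}` will appear literally in the log output.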
@@ -407,12 +409,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
         if torch.cuda.is_available():
             devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
         elif torch.backends.mps.is_available():
-            devices = {torch.device('mps')}
+            devices = {torch.device("mps")}
         else:
-            devices = {torch.device('cpu')}
+            devices = {torch.device("cpu")}
         return devices

     @staticmethod
     def _device_name(device: torch.device) -> str:
         return f"{device.type}:{device.index}"

@@ -399,11 +399,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             unet_attention_patcher = UNetAttentionPatcher(ip_adapters)
             attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)

-        # NOTE error is not here!
-        if conditioning_data.unconditioned_embeddings.embeds.device != \
-           conditioning_data.text_embeddings.embeds.device:
-            print('DEBUG; HERE IS THE ERROR 1')
-
         with attn_ctx:
             if callback is not None:
                 callback(
|
|||||||
|
|
||||||
# print("timesteps:", timesteps)
|
# print("timesteps:", timesteps)
|
||||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||||
if conditioning_data.unconditioned_embeddings.embeds.device != \
|
|
||||||
conditioning_data.text_embeddings.embeds.device:
|
|
||||||
print('DEBUG; HERE IS THE ERROR 2')
|
|
||||||
|
|
||||||
batched_t = t.expand(batch_size)
|
batched_t = t.expand(batch_size)
|
||||||
step_output = self.step(
|
step_output = self.step(
|
||||||
batched_t,
|
batched_t,
|
||||||
@ -466,7 +457,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||||
timestep = t[0]
|
timestep = t[0]
|
||||||
if additional_guidance is None:
|
if additional_guidance is None:
|
||||||
|
@ -4,7 +4,6 @@ import math
|
|||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import threading
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
@ -256,8 +255,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
|
||||||
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
|
||||||
|
|
||||||
if unconditioning.device != conditioning.device:
|
|
||||||
print(f'DEBUG: TID={threading.current_thread().ident}: Unconditioning device = {unconditioning.device}, conditioning device={conditioning.device}')
|
|
||||||
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
|
||||||
|
|
||||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||||
|
@ -1,16 +1,21 @@
|
|||||||
from typing import Dict, Literal, Optional, Union
|
from typing import TYPE_CHECKING, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
|
|
||||||
# legacy APIs
|
# legacy APIs
|
||||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||||
CPU_DEVICE = torch.device("cpu")
|
CPU_DEVICE = torch.device("cpu")
|
||||||
CUDA_DEVICE = torch.device("cuda")
|
CUDA_DEVICE = torch.device("cuda")
|
||||||
MPS_DEVICE = torch.device("mps")
|
MPS_DEVICE = torch.device("mps")
|
||||||
|
|
||||||
|
|
||||||
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
|
||||||
def choose_precision(device: torch.device) -> TorchPrecisionNames:
|
def choose_precision(device: torch.device) -> TorchPrecisionNames:
|
||||||
"""Return the string representation of the recommended torch device."""
|
"""Return the string representation of the recommended torch device."""
|
||||||
@ -41,9 +46,18 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
|
|||||||
class TorchDevice:
|
class TorchDevice:
|
||||||
"""Abstraction layer for torch devices."""
|
"""Abstraction layer for torch devices."""
|
||||||
|
|
||||||
|
_model_cache: Optional["ModelCacheBase[AnyModel]"] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
|
||||||
|
"""Set the current model cache."""
|
||||||
|
cls._model_cache = cache
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def choose_torch_device(cls) -> torch.device:
|
def choose_torch_device(cls) -> torch.device:
|
||||||
"""Return the torch.device to use for accelerated inference."""
|
"""Return the torch.device to use for accelerated inference."""
|
||||||
|
if cls._model_cache:
|
||||||
|
return cls._model_cache.get_execution_device()
|
||||||
app_config = get_config()
|
app_config = get_config()
|
||||||
if app_config.device != "auto":
|
if app_config.device != "auto":
|
||||||
device = torch.device(app_config.device)
|
device = torch.device(app_config.device)
|
||||||
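Taken together with the model-cache change above, device selection becomes dynamic: a reserved per-thread device wins over the static configuration. Below is a simplified, standalone restatement of the lookup order that `TorchDevice.choose_torch_device()` now follows; the class name and the `configured` parameter are illustrative, and the real method reads `get_config().device` and shares state with the actual ModelCache.

import torch


class TorchDeviceSketch:
    _model_cache = None  # set via set_model_cache() when a GPU is reserved

    @classmethod
    def set_model_cache(cls, cache) -> None:
        cls._model_cache = cache

    @classmethod
    def choose_torch_device(cls, configured: str = "auto") -> torch.device:
        # 1. A reserved per-thread device registered by the model cache wins.
        if cls._model_cache is not None:
            return cls._model_cache.get_execution_device()
        # 2. Otherwise honour an explicit `device:` setting from the app config.
        if configured != "auto":
            return torch.device(configured)
        # 3. Otherwise fall back to hardware detection.
        if torch.cuda.is_available():
            return torch.device("cuda")
        if torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")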