diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py
index 87eaefc020..ceaeb95147 100644
--- a/invokeai/app/api_app.py
+++ b/invokeai/app/api_app.py
@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
from .api.dependencies import ApiDependencies
@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}")
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 481a2d2e4b..158f11a58e 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
SDXLConditioningInfo,
)
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .model import CLIPField
@@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
- dtype_for_device_getter=torch_dtype,
+ dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
)
@@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
- dtype_for_device_getter=torch_dtype,
+ dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index ce63d568c6..a8ead96f3a 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField
-if choose_torch_device() == torch.device("mps"):
- from torch import mps
-
-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()
@invocation_output("scheduler_output")
@@ -959,9 +956,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
+ TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@@ -1028,9 +1023,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.disable_tiling()
# clear memory as vae decode can request a lot
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
+ TorchDevice.empty_cache()
with torch.inference_mode():
# copied from diffusers pipeline
@@ -1042,9 +1035,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
+ TorchDevice.empty_cache()
image_dto = context.images.save(image=image)
@@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
-
- # TODO:
- device = choose_torch_device()
+ device = TorchDevice.choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
@@ -1096,9 +1085,8 @@ class ResizeLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
- torch.cuda.empty_cache()
- if device == torch.device("mps"):
- mps.empty_cache()
+
+ TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.tensors.load(self.latents.latents_name)
- # TODO:
- device = choose_torch_device()
+ device = TorchDevice.choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
@@ -1138,9 +1125,7 @@ class ScaleLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
- torch.cuda.empty_cache()
- if device == torch.device("mps"):
- mps.empty_cache()
+ TorchDevice.empty_cache()
name = context.tensors.save(tensor=resized_latents)
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
@@ -1272,8 +1257,7 @@ class BlendLatentsInvocation(BaseInvocation):
if latents_a.shape != latents_b.shape:
raise Exception("Latents to blend must be the same size.")
- # TODO:
- device = choose_torch_device()
+ device = TorchDevice.choose_torch_device()
def slerp(
t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here?
@@ -1326,9 +1310,8 @@ class BlendLatentsInvocation(BaseInvocation):
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
blended_latents = blended_latents.to("cpu")
- torch.cuda.empty_cache()
- if device == torch.device("mps"):
- mps.empty_cache()
+
+ TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py
index 6e5612b8b0..931e639106 100644
--- a/invokeai/app/invocations/noise.py
+++ b/invokeai/app/invocations/noise.py
@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDevice
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -46,7 +46,7 @@ def get_noise(
height // downsampling_factor,
width // downsampling_factor,
],
- dtype=torch_dtype(device),
+ dtype=TorchDevice.choose_torch_dtype(device=device),
device=noise_device_type,
generator=generator,
).to("cpu")
@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
@field_validator("seed", mode="before")
def modulo_seed(cls, v):
- """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+ """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(
width=self.width,
height=self.height,
- device=choose_torch_device(),
+ device=TorchDevice.choose_torch_device(),
seed=self.seed,
use_cpu=self.use_cpu,
)
diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index e09618960e..b8acfcb7bf 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -3,7 +3,6 @@ from typing import Literal
import cv2
import numpy as np
-import torch
from PIL import Image
from pydantic import ConfigDict
@@ -12,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
@@ -33,9 +32,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
"RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
}
-if choose_torch_device() == torch.device("mps"):
- from torch import mps
-
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -115,9 +111,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
+ TorchDevice.empty_cache()
image_dto = context.images.save(image=pil_image)
diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py
index 4b5a2004be..496988e853 100644
--- a/invokeai/app/services/config/config_default.py
+++ b/invokeai/app/services/config/config_default.py
@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.0"
+CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@@ -106,7 +106,7 @@ class InvokeAIAppConfig(BaseSettings):
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
- precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+ precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
@@ -377,6 +377,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
+ # autocast was removed in v4.0.1
+ if k == "precision" and v == "autocast":
+ parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
@@ -399,6 +402,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
return config
+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+ """Migrate v4.0.0 config dictionary to a current config object.
+
+ Args:
+ config_dict: A dictionary of settings from a v4.0.0 config file.
+
+ Returns:
+ An instance of `InvokeAIAppConfig` with the migrated settings.
+ """
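+    # For example, a v4.0.0 file containing {"schema_version": "4.0.0", "precision": "autocast"}
+    # comes back with precision "auto" and schema_version bumped to CONFIG_SCHEMA_VERSION; all
+    # other keys pass through unchanged.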
+ parsed_config_dict: dict[str, Any] = {}
+ for k, v in config_dict.items():
+ # autocast was removed from precision in v4.0.1
+ if k == "precision" and v == "autocast":
+ parsed_config_dict["precision"] = "auto"
+ else:
+ parsed_config_dict[k] = v
+ if k == "schema_version":
+ parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+ config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+ return config
+
+
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@@ -425,17 +450,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
- else:
- # Attempt to load as a v4 config file
- try:
- # Meta is not included in the model fields, so we need to validate it separately
- config = InvokeAIAppConfig.model_validate(loaded_config_dict)
- assert (
- config.schema_version == CONFIG_SCHEMA_VERSION
- ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
- return config
- except Exception as e:
- raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
+
+ if loaded_config_dict["schema_version"] == "4.0.0":
+ loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
+ loaded_config_dict.write_file(config_path)
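+        # migrate_v4_0_0_config_dict returns a full config object rather than a dict; it is
+        # written back to disk here and validated again below.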
+
+ # Attempt to load as a v4 config file
+ try:
+ # Meta is not included in the model fields, so we need to validate it separately
+ config = InvokeAIAppConfig.model_validate(loaded_config_dict)
+ assert (
+ config.schema_version == CONFIG_SCHEMA_VERSION
+ ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
+ return config
+ except Exception as e:
+ raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index f1fbcdb7ba..f0e25648bf 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
+import torch
import yaml
from huggingface_hub import HfFolder
from pydantic.networks import AnyHttpUrl
@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
@@ -643,11 +644,10 @@ class ModelInstallService(ModelInstallServiceBase):
self._next_job_id += 1
return id
- @staticmethod
- def _guess_variant() -> Optional[ModelRepoVariant]:
+ def _guess_variant(self) -> Optional[ModelRepoVariant]:
"""Guess the best HuggingFace variant type to download."""
- precision = choose_precision(choose_torch_device())
- return ModelRepoVariant.FP16 if precision == "float16" else None
+ precision = TorchDevice.choose_torch_dtype()
+ return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py
index de6e5f09d8..1a2b9a3402 100644
--- a/invokeai/app/services/model_manager/model_manager_default.py
+++ b/invokeai/app/services/model_manager/model_manager_default.py
@@ -1,12 +1,14 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
"""Implementation of ModelManagerServiceBase."""
+from typing import Optional
+
import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from ..config import InvokeAIAppConfig
@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_record_service: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
events: EventServiceBase,
- execution_device: torch.device = choose_torch_device(),
+ execution_device: Optional[torch.device] = None,
) -> Self:
"""
Construct the model manager service instance.
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
logger=logger,
- execution_device=execution_device,
+ execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py
index 560d977b55..2d88c45485 100644
--- a/invokeai/backend/image_util/depth_anything/__init__.py
+++ b/invokeai/backend/image_util/depth_anything/__init__.py
@@ -12,7 +12,7 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
config = get_config()
@@ -47,7 +47,7 @@ class DepthAnythingDetector:
self.context = context
self.model: Optional[DPT_DINOv2] = None
self.model_size: Union[Literal["large", "base", "small"], None] = None
- self.device = choose_torch_device()
+ self.device = TorchDevice.choose_torch_device()
def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
@@ -68,7 +68,7 @@ class DepthAnythingDetector:
self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
self.model.eval()
- self.model.to(choose_torch_device())
+ self.model.to(self.device)
return self.model
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
@@ -81,7 +81,7 @@ class DepthAnythingDetector:
image_height, image_width = np_image.shape[:2]
np_image = transform({"image": np_image})["image"]
- tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
+ tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)
with torch.no_grad():
depth = self.model(tensor_image)
diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py
index 3628b0abd5..0f66af2c77 100644
--- a/invokeai/backend/image_util/dw_openpose/wholebody.py
+++ b/invokeai/backend/image_util/dw_openpose/wholebody.py
@@ -4,11 +4,10 @@
import numpy as np
import onnxruntime as ort
-import torch
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from .onnxdet import inference_detector
from .onnxpose import inference_pose
@@ -23,9 +22,9 @@ config = get_config()
class Wholebody:
def __init__(self, context: InvocationContext):
- device = choose_torch_device()
+ device = TorchDevice.choose_torch_device()
- providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
+ providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]
onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])
diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py
index 7c4d90f5bd..c5fe3fa598 100644
--- a/invokeai/backend/image_util/realesrgan/realesrgan.py
+++ b/invokeai/backend/image_util/realesrgan/realesrgan.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
from invokeai.backend.model_manager.config import AnyModel
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
"""
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
@@ -65,7 +65,7 @@ class RealESRGAN:
self.pre_pad = pre_pad
self.mod_scale: Optional[int] = None
self.half = half
- self.device = choose_torch_device()
+ self.device = TorchDevice.choose_torch_device()
# prefer to use params_ema
if "params_ema" in loadnet:
diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py
index 7bceae8da7..60dcd93fcc 100644
--- a/invokeai/backend/image_util/safety_checker.py
+++ b/invokeai/backend/image_util/safety_checker.py
@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@@ -51,7 +51,7 @@ class SafetyChecker:
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False
- device = choose_torch_device()
+ device = TorchDevice.choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
cls.safety_checker.to(device)
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 451770c0cb..459808b455 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice
# TO DO: The loader is not thread safe!
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
- self._torch_dtype = torch_dtype(choose_torch_device())
+ self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index 919a7c4396..62bb766cd6 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -31,15 +31,12 @@ import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
-if choose_torch_device() == torch.device("mps"):
- from torch import mps
-
# Maximum size of the cache, in gigs
# Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -245,9 +242,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
+ TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device.
@@ -417,10 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.stats.cleared = models_cleared
gc.collect()
- torch.cuda.empty_cache()
- if choose_torch_device() == torch.device("mps"):
- mps.empty_cache()
-
+ TorchDevice.empty_cache()
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py
index eb6fd45e1a..125e99be93 100644
--- a/invokeai/backend/model_manager/merge.py
+++ b/invokeai/backend/model_manager/merge.py
@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice
from . import (
AnyModelConfig,
@@ -43,6 +43,7 @@ class ModelMerger(object):
Initialize a ModelMerger object with the model installer.
"""
self._installer = installer
+ self._dtype = TorchDevice.choose_torch_dtype()
def merge_diffusion_models(
self,
@@ -68,7 +69,7 @@ class ModelMerger(object):
warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error()
- dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+ dtype = torch.float16 if variant == "fp16" else self._dtype
# Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
# until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.
@@ -151,7 +152,7 @@ class ModelMerger(object):
dump_path.mkdir(parents=True, exist_ok=True)
dump_path = dump_path / merged_model_name
- dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+ dtype = torch.float16 if variant == "fp16" else self._dtype
merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)
# register model and get its unique key
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index b4d1b3381c..bd60b0b8c7 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import normalize_device
+from invokeai.backend.util.devices import TorchDevice
@dataclass
@@ -258,7 +258,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free
elif self.unet.device.type == "cuda":
- mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
+ mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
else:
raise ValueError(f"unrecognized device {self.unet.device}")
# input tensor of [1, 4, h/8, w/8]
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index 2c9cceff2c..1e4d467cd0 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -2,7 +2,6 @@
Initialization file for invokeai.backend.util
"""
-from .devices import choose_precision, choose_torch_device
from .logging import InvokeAILogger
from .util import GIG, Chdir, directory_size
@@ -11,6 +10,4 @@ __all__ = [
"directory_size",
"Chdir",
"InvokeAILogger",
- "choose_precision",
- "choose_torch_device",
]
diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py
index cb6b93eaac..e8380dc8bc 100644
--- a/invokeai/backend/util/devices.py
+++ b/invokeai/backend/util/devices.py
@@ -1,89 +1,110 @@
-from __future__ import annotations
-
-from contextlib import nullcontext
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union
import torch
-from torch import autocast
+from deprecated import deprecated
-from invokeai.app.services.config.config_default import PRECISION, get_config
+from invokeai.app.services.config.config_default import get_config
+# legacy APIs
+TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
+def choose_precision(device: torch.device) -> TorchPrecisionNames:
+    """Return the string name of the recommended torch dtype for the given device."""
+ torch_dtype = TorchDevice.choose_torch_dtype(device)
+ return PRECISION_TO_NAME[torch_dtype]
+
+
+@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore
def choose_torch_device() -> torch.device:
- """Convenience routine for guessing which GPU device to run model on"""
- config = get_config()
- if config.device == "auto":
- if torch.cuda.is_available():
- return torch.device("cuda")
- if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
- return torch.device("mps")
+ """Return the torch.device to use for accelerated inference."""
+ return TorchDevice.choose_torch_device()
+
+
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore
+def torch_dtype(device: torch.device) -> torch.dtype:
+    """Return the recommended torch dtype for the given device."""
+ return TorchDevice.choose_torch_dtype(device)
+
+
+NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
+ "float32": torch.float32,
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+}
+PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
+
+
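+# Rough mapping from the removed/deprecated module-level helpers to the class below:
+#   choose_torch_device()      -> TorchDevice.choose_torch_device()
+#   choose_precision(device)   -> TorchDevice.choose_torch_dtype(device)  (returns a torch.dtype, not a string)
+#   torch_dtype(device)        -> TorchDevice.choose_torch_dtype(device)
+#   normalize_device(device)   -> TorchDevice.normalize(device)
+#   get_torch_device_name()    -> TorchDevice.get_torch_device_name()
+#   backend-specific empty_cache() calls -> TorchDevice.empty_cache()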
+class TorchDevice:
+ """Abstraction layer for torch devices."""
+
+ @classmethod
+ def choose_torch_device(cls) -> torch.device:
+ """Return the torch.device to use for accelerated inference."""
+ app_config = get_config()
+ if app_config.device != "auto":
+ device = torch.device(app_config.device)
+ elif torch.cuda.is_available():
+ device = CUDA_DEVICE
+ elif torch.backends.mps.is_available():
+ device = MPS_DEVICE
else:
- return CPU_DEVICE
- else:
- return torch.device(config.device)
+ device = CPU_DEVICE
+ return cls.normalize(device)
+ @classmethod
+ def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+ """Return the precision to use for accelerated inference."""
+ device = device or cls.choose_torch_device()
+ config = get_config()
+ if device.type == "cuda" and torch.cuda.is_available():
+ device_name = torch.cuda.get_device_name(device)
+ if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+ # These GPUs have limited support for float16
+ return cls._to_dtype("float32")
+ elif config.precision == "auto":
+ # Default to float16 for CUDA devices
+ return cls._to_dtype("float16")
+ else:
+ # Use the user-defined precision
+ return cls._to_dtype(config.precision)
-def get_torch_device_name() -> str:
- device = choose_torch_device()
- return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
+ elif device.type == "mps" and torch.backends.mps.is_available():
+ if config.precision == "auto":
+ # Default to float16 for MPS devices
+ return cls._to_dtype("float16")
+ else:
+ # Use the user-defined precision
+ return cls._to_dtype(config.precision)
+ # CPU / safe fallback
+ return cls._to_dtype("float32")
+ @classmethod
+ def get_torch_device_name(cls) -> str:
+ """Return the device name for the current torch device."""
+ device = cls.choose_torch_device()
+ return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
-def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
- """Return an appropriate precision for the given torch device."""
- app_config = get_config()
- if device.type == "cuda":
- device_name = torch.cuda.get_device_name(device)
- if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
- # These GPUs have limited support for float16
- return "float32"
- elif app_config.precision == "auto" or app_config.precision == "autocast":
- # Default to float16 for CUDA devices
- return "float16"
- else:
- # Use the user-defined precision
- return app_config.precision
- elif device.type == "mps":
- if app_config.precision == "auto" or app_config.precision == "autocast":
- # Default to float16 for MPS devices
- return "float16"
- else:
- # Use the user-defined precision
- return app_config.precision
- # CPU / safe fallback
- return "float32"
-
-
-def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
- device = device or choose_torch_device()
- precision = choose_precision(device)
- if precision == "float16":
- return torch.float16
- if precision == "bfloat16":
- return torch.bfloat16
- else:
- # "auto", "autocast", "float32"
- return torch.float32
-
-
-def choose_autocast(precision: PRECISION):
- """Returns an autocast context or nullcontext for the given precision string"""
- # float16 currently requires autocast to avoid errors like:
- # 'expected scalar type Half but found Float'
- if precision == "autocast" or precision == "float16":
- return autocast
- return nullcontext
-
-
-def normalize_device(device: Union[str, torch.device]) -> torch.device:
- """Ensure device has a device index defined, if appropriate."""
- device = torch.device(device)
- if device.index is None:
- # cuda might be the only torch backend that currently uses the device index?
- # I don't see anything like `current_device` for cpu or mps.
- if device.type == "cuda":
+ @classmethod
+ def normalize(cls, device: Union[str, torch.device]) -> torch.device:
+ """Add the device index to CUDA devices."""
+ device = torch.device(device)
+ if device.index is None and device.type == "cuda" and torch.cuda.is_available():
device = torch.device(device.type, torch.cuda.current_device())
- return device
+ return device
+
+ @classmethod
+ def empty_cache(cls) -> None:
+ """Clear the GPU device cache."""
+ if torch.backends.mps.is_available():
+ torch.mps.empty_cache()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ @classmethod
+ def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
+ return NAME_TO_PRECISION[precision_name]
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index 2f4e70005f..1adac4c5dc 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -770,6 +770,8 @@
"float": "Float",
"fullyContainNodes": "Fully Contain Nodes to Select",
"fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
+ "showEdgeLabels": "Show Edge Labels",
+ "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
"hideLegendNodes": "Hide Field Type Legend",
"hideMinimapnodes": "Hide MiniMap",
"inputMayOnlyHaveOneConnection": "Input may only have one connection",
diff --git a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts b/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts
index 372452ce99..bbb7897575 100644
--- a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts
+++ b/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts
@@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
export const useGlobalHotkeys = () => {
const dispatch = useAppDispatch();
- const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
+ const isModelManagerEnabled = useFeatureStatus('modelManager');
const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();
useHotkeys(
diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx
index ad6c37532e..a67fd5f82d 100644
--- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx
@@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board_id);
- const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
+ const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();
diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx
index 15ab74f44d..880fdbca6c 100644
--- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx
@@ -54,7 +54,7 @@ const CurrentImageButtons = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
- const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
+ const isUpscalingEnabled = useFeatureStatus('upscaling');
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const toaster = useAppToaster();
const { t } = useTranslation();
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx
index 7b1fa73472..54fb19c844 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx
@@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
const selection = useAppSelector((s) => s.gallery.selection);
const customStarUi = useStore($customStarUI);
- const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
+ const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx
index 7f43ab3671..aff74481ca 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx
@@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const toaster = useAppToaster();
- const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
+ const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
const customStarUi = useStore($customStarUI);
const { downloadImage } = useDownloadImage();
diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts
index 05e7d075e5..f84a349d2a 100644
--- a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts
+++ b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts
@@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
[imageDTO?.image_name]
);
const isSelected = useAppSelector(selectIsSelected);
- const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
+ const isMultiSelectEnabled = useFeatureStatus('multiselect');
const handleClick = useCallback(
(e: MouseEvent) => {
diff --git a/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx b/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx
index 2cb96a935f..eaa6d60dda 100644
--- a/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx
+++ b/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx
@@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle';
export const HrfSettings = memo(() => {
- const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
+ const isHRFFeatureEnabled = useFeatureStatus('hrf');
const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);
if (!isHRFFeatureEnabled) {
diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx
index 0ee06f19dd..6abc633ac8 100644
--- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx
+++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx
@@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';
export const useStarterModelsToast = () => {
const { t } = useTranslation();
- const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
+ const isEnabled = useFeatureStatus('starterModels');
const [didToast, setDidToast] = useState(false);
const [mainModels, { data }] = useMainModels();
const toast = useToast();
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx
index 0a18c2a959..77be60a945 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx
@@ -1,8 +1,9 @@
+import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow';
-import { BaseEdge, getBezierPath } from 'reactflow';
+import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
import { makeEdgeSelector } from './util/makeEdgeSelector';
@@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
[source, sourceHandleId, target, targetHandleId, selected]
);
- const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
+ const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
+ const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);
- const [edgePath] = getBezierPath({
+ const [edgePath, labelX, labelY] = getBezierPath({
sourceX,
sourceY,
sourcePosition,
@@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
[isSelected, shouldAnimate, stroke]
);
-  return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
+  return (
+    <>
+      <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
+      {label && shouldShowEdgeLabels && (
+        <EdgeLabelRenderer>
+          <Flex
+            className="nodrag nopan"
+            position="absolute"
+            transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
+            pointerEvents="all"
+          >
+            <Text size="sm">{label}</Text>
+          </Flex>
+        </EdgeLabelRenderer>
+      )}
+    </>
+  );
};
export default memo(InvocationDefaultEdge);
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts
index ba40b4984c..a485bf64c1 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts
@@ -1,7 +1,7 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
+import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
@@ -10,6 +10,7 @@ const defaultReturnValue = {
isSelected: false,
shouldAnimate: false,
stroke: colorTokenToCssVar('base.500'),
+ label: '',
};
export const makeEdgeSelector = (
@@ -19,25 +20,34 @@ export const makeEdgeSelector = (
targetHandleId: string | null | undefined,
selected?: boolean
) =>
- createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
- const sourceNode = nodes.nodes.find((node) => node.id === source);
- const targetNode = nodes.nodes.find((node) => node.id === target);
+ createMemoizedSelector(
+ selectNodesSlice,
+ (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
+ const sourceNode = nodes.nodes.find((node) => node.id === source);
+ const targetNode = nodes.nodes.find((node) => node.id === target);
- const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
+ const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
- const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
- if (!sourceNode || !sourceHandleId) {
- return defaultReturnValue;
+ const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);
+ if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
+ return defaultReturnValue;
+ }
+
+ const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
+ const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
+
+ const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
+
+ const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
+ const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
+
+ const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
+
+ return {
+ isSelected,
+ shouldAnimate: nodes.shouldAnimateEdges && isSelected,
+ stroke,
+ label,
+ };
}
-
- const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId);
- const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
-
- const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
-
- return {
- isSelected,
- shouldAnimate: nodes.shouldAnimateEdges && isSelected,
- stroke,
- };
- });
+ );
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx
index 1b93c0fdd3..c1ff625d25 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx
@@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };
const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useHasImageOutput(nodeId);
- const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+ const isCacheEnabled = useFeatureStatus('invocationCache');
return (
{
- const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
+ const {
+ shouldAnimateEdges,
+ shouldValidateGraph,
+ shouldSnapToGrid,
+ shouldColorEdges,
+ shouldShowEdgeLabels,
+ selectionMode,
+ } = nodes;
return {
shouldAnimateEdges,
shouldValidateGraph,
shouldSnapToGrid,
shouldColorEdges,
+ shouldShowEdgeLabels,
selectionModeIsChecked: selectionMode === SelectionMode.Full,
};
});
@@ -52,8 +61,14 @@ type Props = {
const WorkflowEditorSettings = ({ children }: Props) => {
const { isOpen, onOpen, onClose } = useDisclosure();
const dispatch = useAppDispatch();
- const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
- useAppSelector(selector);
+ const {
+ shouldAnimateEdges,
+ shouldValidateGraph,
+ shouldSnapToGrid,
+ shouldColorEdges,
+ shouldShowEdgeLabels,
+ selectionModeIsChecked,
+ } = useAppSelector(selector);
const handleChangeShouldValidate = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
@@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
[dispatch]
);
+ const handleChangeShouldShowEdgeLabels = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+ dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
+ },
+ [dispatch]
+ );
+
const { t } = useTranslation();
return (
@@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
{t('nodes.fullyContainNodesHelp')}
+            <FormControl>
+              <Flex w="full">
+                <FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
+                <Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
+              </Flex>
+              <FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
+            </FormControl>
{t('common.advanced')}
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts
index e16b329c22..54c092370b 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts
@@ -1,5 +1,5 @@
-import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
@@ -10,7 +10,7 @@ import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo(
() =>
- createSelector(selectNodesSlice, (nodes) => {
+ createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts
index b3676f9722..6e00d374f6 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts
@@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';
export const useWithFooter = (nodeId: string) => {
const hasImageOutput = useHasImageOutput(nodeId);
- const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
-
+ const isCacheEnabled = useFeatureStatus('invocationCache');
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
return withFooter;
};
diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
index 4a1b438271..0f0417cf71 100644
--- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts
@@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
shouldAnimateEdges: true,
shouldSnapToGrid: false,
shouldColorEdges: true,
+ shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false,
nodeOpacity: 1,
selectedNodes: [],
@@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
    shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload;
},
+    shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
+ state.shouldShowEdgeLabels = action.payload;
+ },
    shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload;
},
@@ -831,6 +835,7 @@ export const {
viewportChanged,
edgeAdded,
nodeTemplatesBuilt,
+ shouldShowEdgeLabelsChanged,
} = nodesSlice.actions;
// This is used for tracking `state.workflow.isTouched`
diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts
index 3f7d210dcb..2074f1f342 100644
--- a/invokeai/frontend/web/src/features/nodes/store/types.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/types.ts
@@ -32,6 +32,7 @@ export type NodesState = {
isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode;
+ shouldShowEdgeLabels: boolean;
};
export type WorkflowMode = 'edit' | 'view';
diff --git a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx
index 1cafe4310e..be173ca6ca 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx
@@ -1,24 +1,18 @@
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
-import { createSelector } from '@reduxjs/toolkit';
+import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
-import { memo, useCallback, useMemo } from 'react';
+import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
+const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);
+
const ParamInfillColorOptions = () => {
const dispatch = useAppDispatch();
- const selector = useMemo(
- () =>
- createSelector(selectGenerationSlice, (generation) => ({
- infillColor: generation.infillColorValue,
- })),
- []
- );
-
- const { infillColor } = useAppSelector(selector);
+ const infillColor = useAppSelector(selectInfillColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
diff --git a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx
index f164bb903e..cfdd7fb010 100644
--- a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx
+++ b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx
@@ -1,35 +1,23 @@
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
-import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker';
import {
- selectGenerationSlice,
setInfillMosaicMaxColor,
setInfillMosaicMinColor,
setInfillMosaicTileHeight,
setInfillMosaicTileWidth,
} from 'features/parameters/store/generationSlice';
-import { memo, useCallback, useMemo } from 'react';
+import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next';
const ParamInfillMosaicTileSize = () => {
const dispatch = useAppDispatch();
- const selector = useMemo(
- () =>
- createSelector(selectGenerationSlice, (generation) => ({
- infillMosaicTileWidth: generation.infillMosaicTileWidth,
- infillMosaicTileHeight: generation.infillMosaicTileHeight,
- infillMosaicMinColor: generation.infillMosaicMinColor,
- infillMosaicMaxColor: generation.infillMosaicMaxColor,
- })),
- []
- );
-
- const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
- useAppSelector(selector);
-
+ const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
+ const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
+ const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
+ const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
const infillMethod = useAppSelector((s) => s.generation.infillMethod);
const { t } = useTranslation();
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx b/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx
index 37b6318a53..101c82376c 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx
@@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const clearQueueDisclosure = useDisclosure();
- const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
- const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
+ const isPauseEnabled = useFeatureStatus('pauseQueue');
+ const isResumeEnabled = useFeatureStatus('resumeQueue');
const { queueSize } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx b/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx
index 5499ecccfc..28a12808ea 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx
@@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
import { QueueActionsMenuButton } from './QueueActionsMenuButton';
const QueueControls = () => {
- const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
+ const isPrependEnabled = useFeatureStatus('prependQueue');
return (
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx b/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx
index d7c27d54fc..2dae5e6ebe 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx
@@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
import QueueTabQueueControls from './QueueTabQueueControls';
const QueueTabContent = () => {
- const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+ const isInvocationCacheEnabled = useFeatureStatus('invocationCache');
return (
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx b/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx
index b42fc04e4d..3aed2f237a 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx
@@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton';
const QueueTabQueueControls = () => {
- const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
- const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
+ const isPauseEnabled = useFeatureStatus('pauseQueue');
+ const isResumeEnabled = useFeatureStatus('resumeQueue');
return (
{isPauseEnabled || isResumeEnabled ? (
diff --git a/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts b/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts
index d39ac96566..f6c71dbc5a 100644
--- a/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts
+++ b/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts
@@ -13,7 +13,7 @@ export const useQueueFront = () => {
const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch',
});
- const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
+ const prependEnabled = useFeatureStatus('prependQueue');
const isDisabled = useMemo(() => {
return !isReady || !prependEnabled;
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx
index 3cddc927e9..ec81b0b211 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx
@@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
export const ControlSettingsAccordion: React.FC = memo(() => {
const { t } = useTranslation();
const { controlAdapterIds, badges } = useAppSelector(selector);
- const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
+ const isControlNetEnabled = useFeatureStatus('controlNet');
const { isOpen, onToggle } = useStandaloneAccordionToggle({
id: 'control-settings',
defaultIsOpen: true,
@@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
- if (isControlNetDisabled) {
+ if (!isControlNetEnabled) {
return null;
}
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx
index eceba85b5a..bee5940530 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx
@@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const language = useAppSelector((s) => s.system.language);
- const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled;
+ const isLocalizationEnabled = useFeatureStatus('localization');
const value = useMemo(() => options.find((o) => o.value === language), [language]);
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx
index 6cf37334e7..b424a129ee 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx
@@ -23,9 +23,9 @@ const SettingsMenu = () => {
const { isOpen, onOpen, onClose } = useDisclosure();
useGlobalMenuClose(onClose);
- const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
- const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled;
- const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled;
+ const isBugLinkEnabled = useFeatureStatus('bugLink');
+ const isDiscordLinkEnabled = useFeatureStatus('discordLink');
+ const isGithubLinkEnabled = useFeatureStatus('githubLink');
return (