fix merge conflicts with main

Lincoln Stein 2024-04-15 09:24:57 -04:00
commit 470a39935c
47 changed files with 476 additions and 277 deletions

View File

@@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.session_processor.session_processor_common import ProgressImage
-from invokeai.backend.util.devices import get_torch_device_name
+from invokeai.backend.util.devices import TorchDevice
 from ..backend.util.logging import InvokeAILogger
 from .api.dependencies import ApiDependencies

@@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config)
 mimetypes.add_type("application/javascript", ".js")
 mimetypes.add_type("text/css", ".css")
-torch_device_name = get_torch_device_name()
+torch_device_name = TorchDevice.get_torch_device_name()
 logger.info(f"Using torch device: {torch_device_name}")

View File

@@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     ConditioningFieldData,
     SDXLConditioningInfo,
 )
-from invokeai.backend.util.devices import torch_dtype
+from invokeai.backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .model import CLIPField

@@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation):
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
         )

@@ -193,7 +193,7 @@ class SDXLPromptInvocationBase:
             tokenizer=tokenizer,
             text_encoder=text_encoder,
             textual_inversion_manager=ti_manager,
-            dtype_for_device_getter=torch_dtype,
+            dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,

View File

@@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
     image_resized_to_grid_as_tensor,
 )
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
-from ...backend.util.devices import choose_precision, choose_torch_device
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from .controlnet_image_processors import ControlField
 from .model import ModelIdentifierField, UNetField, VAEField

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
-DEFAULT_PRECISION = choose_precision(choose_torch_device())
+DEFAULT_PRECISION = TorchDevice.choose_torch_dtype()

 @invocation_output("scheduler_output")

@@ -959,9 +956,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)

@@ -1028,9 +1023,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             vae.disable_tiling()

         # clear memory as vae decode can request a lot
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         with torch.inference_mode():
             # copied from diffusers pipeline

@@ -1042,9 +1035,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
             image = VaeImageProcessor.numpy_to_pil(np_image)[0]
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         image_dto = context.images.save(image=image)

@@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         resized_latents = torch.nn.functional.interpolate(
             latents.to(device),

@@ -1096,9 +1085,8 @@ class ResizeLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

@@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.tensors.load(self.latents.latents_name)
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         # resizing
         resized_latents = torch.nn.functional.interpolate(

@@ -1138,9 +1125,7 @@ class ScaleLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=resized_latents)
         return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)

@@ -1272,8 +1257,7 @@ class BlendLatentsInvocation(BaseInvocation):
         if latents_a.shape != latents_b.shape:
             raise Exception("Latents to blend must be the same size.")
-        # TODO:
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()

         def slerp(
             t: Union[float, npt.NDArray[Any]],  # FIXME: maybe use np.float32 here?

@@ -1326,9 +1310,8 @@ class BlendLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
-        torch.cuda.empty_cache()
-        if device == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         name = context.tensors.save(tensor=blended_latents)
         return LatentsOutput.build(latents_name=name, latents=blended_latents)
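Note: the recurring change in this file is the cache-clearing pattern collapsing into one call. A minimal sketch of the new call site, using only the TorchDevice helper introduced later in this commit (invokeai/backend/util/devices):

    from invokeai.backend.util.devices import TorchDevice

    # Before, every call site had to special-case the backend:
    #   torch.cuda.empty_cache()
    #   if choose_torch_device() == torch.device("mps"):
    #       mps.empty_cache()
    # After, a single call is safe on CPU, CUDA, and MPS alike:
    TorchDevice.empty_cache()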

View File

@@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.misc import SEED_MAX
-from ...backend.util.devices import choose_torch_device, torch_dtype
+from ...backend.util.devices import TorchDevice
 from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,

@@ -46,7 +46,7 @@ def get_noise(
             height // downsampling_factor,
             width // downsampling_factor,
         ],
-        dtype=torch_dtype(device),
+        dtype=TorchDevice.choose_torch_dtype(device=device),
         device=noise_device_type,
         generator=generator,
     ).to("cpu")

@@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation):
     @field_validator("seed", mode="before")
     def modulo_seed(cls, v):
-        """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
+        """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
         return v % (SEED_MAX + 1)

     def invoke(self, context: InvocationContext) -> NoiseOutput:
         noise = get_noise(
             width=self.width,
             height=self.height,
-            device=choose_torch_device(),
+            device=TorchDevice.choose_torch_device(),
             seed=self.seed,
             use_cpu=self.use_cpu,
         )

View File

@@ -3,7 +3,6 @@ from typing import Literal
 import cv2
 import numpy as np
-import torch
 from PIL import Image
 from pydantic import ConfigDict

@@ -12,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from .baseinvocation import BaseInvocation, invocation
 from .fields import InputField, WithBoard, WithMetadata

@@ -33,9 +32,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = {
     "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
 }
-if choose_torch_device() == torch.device("mps"):
-    from torch import mps

 @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2")
 class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):

@@ -115,9 +111,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
         pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         image_dto = context.images.save(image=pil_image)

View File

@@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0
 DEFAULT_VRAM_CACHE = 0.25
 DEFAULT_CONVERT_CACHE = 20.0
 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.0"
+CONFIG_SCHEMA_VERSION = "4.0.1"

 def get_default_ram_cache_size() -> float:

@@ -106,7 +106,7 @@ class InvokeAIAppConfig(BaseSettings):
         lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
         device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`

@@ -377,6 +377,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
             # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
             if k == "max_vram_cache_size" and "vram" not in category_dict:
                 parsed_config_dict["vram"] = v
+            # autocast was removed in v4.0.1
+            if k == "precision" and v == "autocast":
+                parsed_config_dict["precision"] = "auto"
             if k == "conf_path":
                 parsed_config_dict["legacy_models_yaml_path"] = v
             if k == "legacy_conf_dir":

@@ -399,6 +402,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
     return config

+def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.0 config dictionary to a current config object.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.0 config file.
+
+    Returns:
+        An instance of `InvokeAIAppConfig` with the migrated settings.
+    """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, v in config_dict.items():
+        # autocast was removed from precision in v4.0.1
+        if k == "precision" and v == "autocast":
+            parsed_config_dict["precision"] = "auto"
+        else:
+            parsed_config_dict[k] = v
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config

 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.

@@ -425,7 +450,11 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
             raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
         migrated_config.write_file(config_path)
         return migrated_config
+    else:
+        if loaded_config_dict["schema_version"] == "4.0.0":
+            loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
+            loaded_config_dict.write_file(config_path)

     # Attempt to load as a v4 config file
     try:
         # Meta is not included in the model fields, so we need to validate it separately
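For context, a hedged sketch of what the v4.0.0 → v4.0.1 migration above does to a settings dict. The dict literal is illustrative, and importing migrate_v4_0_0_config_dict directly assumes it remains a module-level function of config_default as shown in the hunk:

    from invokeai.app.services.config.config_default import migrate_v4_0_0_config_dict

    # Illustrative v4.0.0 settings that still use the removed "autocast" value.
    old_settings = {"schema_version": "4.0.0", "precision": "autocast"}

    migrated = migrate_v4_0_0_config_dict(old_settings)
    assert migrated.precision == "auto"        # "autocast" is rewritten to "auto"
    assert migrated.schema_version == "4.0.1"  # bumped to CONFIG_SCHEMA_VERSION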

View File

@@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree
 from tempfile import mkdtemp
 from typing import Any, Dict, List, Optional, Union

+import torch
 import yaml
 from huggingface_hub import HfFolder
 from pydantic.networks import AnyHttpUrl

@@ -42,7 +43,7 @@ from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet
 from invokeai.backend.model_manager.probe import ModelProbe
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util import InvokeAILogger
-from invokeai.backend.util.devices import choose_precision, choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from .model_install_base import (
     MODEL_SOURCE_TO_TYPE_MAP,

@@ -643,11 +644,10 @@ class ModelInstallService(ModelInstallServiceBase):
         self._next_job_id += 1
         return id

-    @staticmethod
-    def _guess_variant() -> Optional[ModelRepoVariant]:
+    def _guess_variant(self) -> Optional[ModelRepoVariant]:
         """Guess the best HuggingFace variant type to download."""
-        precision = choose_precision(choose_torch_device())
-        return ModelRepoVariant.FP16 if precision == "float16" else None
+        precision = TorchDevice.choose_torch_dtype()
+        return ModelRepoVariant.FP16 if precision == torch.float16 else None

     def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
         return ModelInstallJob(
View File

@@ -1,12 +1,14 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

+from typing import Optional
+
 import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig

@@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device = choose_torch_device(),
+        execution_device: Optional[torch.device] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase):
             max_vram_cache_size=app_config.vram,
             lazy_offloading=app_config.lazy_offload,
             logger=logger,
-            execution_device=execution_device,
+            execution_device=execution_device or TorchDevice.choose_torch_device(),
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(

View File

@@ -12,7 +12,7 @@ from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 config = get_config()

@@ -47,7 +47,7 @@ class DepthAnythingDetector:
         self.context = context
         self.model: Optional[DPT_DINOv2] = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()

     def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
         depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])

@@ -68,7 +68,7 @@ class DepthAnythingDetector:
         self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
         self.model.eval()
-        self.model.to(choose_torch_device())
+        self.model.to(self.device)
         return self.model

     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:

@@ -81,7 +81,7 @@ class DepthAnythingDetector:
         image_height, image_width = np_image.shape[:2]
         np_image = transform({"image": np_image})["image"]
-        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device)

         with torch.no_grad():
             depth = self.model(tensor_image)

View File

@@ -4,11 +4,10 @@
 import numpy as np
 import onnxruntime as ort
-import torch

 from invokeai.app.services.config.config_default import get_config
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from .onnxdet import inference_detector
 from .onnxpose import inference_pose

@@ -23,9 +22,9 @@ config = get_config()
 class Wholebody:
     def __init__(self, context: InvocationContext):
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
-        providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"]
+        providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"]

         onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"])
         onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"])

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm

 from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
 from invokeai.backend.model_manager.config import AnyModel
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice

 """
 Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py

@@ -65,7 +65,7 @@ class RealESRGAN:
         self.pre_pad = pre_pad
         self.mod_scale: Optional[int] = None
         self.half = half
-        self.device = choose_torch_device()
+        self.device = TorchDevice.choose_torch_device()

         # prefer to use params_ema
         if "params_ema" in loadnet:

View File

@@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor

 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings

 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"

@@ -51,7 +51,7 @@ class SafetyChecker:
             cls._load_safety_checker()
         if cls.safety_checker is None or cls.feature_extractor is None:
             return False
-        device = choose_torch_device()
+        device = TorchDevice.choose_torch_device()
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
         cls.safety_checker.to(device)

View File

@@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice

 # TO DO: The loader is not thread safe!

@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._torch_dtype = TorchDevice.choose_torch_dtype()

     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """

View File

@@ -31,15 +31,12 @@ import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
-from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
 from .model_locker import ModelLocker

-if choose_torch_device() == torch.device("mps"):
-    from torch import mps
-
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0

@@ -245,9 +242,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
             )
-            torch.cuda.empty_cache()
-            if choose_torch_device() == torch.device("mps"):
-                mps.empty_cache()
+            TorchDevice.empty_cache()

     def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
         """Move model into the indicated device.

@@ -417,10 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self.stats.cleared = models_cleared
         gc.collect()
-        torch.cuda.empty_cache()
-        if choose_torch_device() == torch.device("mps"):
-            mps.empty_cache()
+        TorchDevice.empty_cache()

         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")

     def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:

View File

@@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging

 from invokeai.app.services.model_install import ModelInstallServiceBase
 from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
-from invokeai.backend.util.devices import choose_torch_device, torch_dtype
+from invokeai.backend.util.devices import TorchDevice

 from . import (
     AnyModelConfig,

@@ -43,6 +43,7 @@ class ModelMerger(object):
         Initialize a ModelMerger object with the model installer.
         """
         self._installer = installer
+        self._dtype = TorchDevice.choose_torch_dtype()

     def merge_diffusion_models(
         self,

@@ -68,7 +69,7 @@ class ModelMerger(object):
             warnings.simplefilter("ignore")
             verbosity = dlogging.get_verbosity()
             dlogging.set_verbosity_error()
-            dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+            dtype = torch.float16 if variant == "fp16" else self._dtype

             # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models
             # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released.

@@ -151,7 +152,7 @@ class ModelMerger(object):
         dump_path.mkdir(parents=True, exist_ok=True)
         dump_path = dump_path / merged_model_name

-        dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device())
+        dtype = torch.float16 if variant == "fp16" else self._dtype
         merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant)

         # register model and get its unique key

View File

@@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
 from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import normalize_device
+from invokeai.backend.util.devices import TorchDevice

 @dataclass

@@ -258,7 +258,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
             mem_free = psutil.virtual_memory().free
         elif self.unet.device.type == "cuda":
-            mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device))
+            mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device))
         else:
             raise ValueError(f"unrecognized device {self.unet.device}")
         # input tensor of [1, 4, h/8, w/8]

View File

@@ -2,7 +2,6 @@
 Initialization file for invokeai.backend.util
 """

-from .devices import choose_precision, choose_torch_device
 from .logging import InvokeAILogger
 from .util import GIG, Chdir, directory_size

@@ -11,6 +10,4 @@ __all__ = [
     "directory_size",
     "Chdir",
     "InvokeAILogger",
-    "choose_precision",
-    "choose_torch_device",
 ]

View File

@@ -1,89 +1,110 @@
-from __future__ import annotations
-
-from contextlib import nullcontext
-from typing import Literal, Optional, Union
+from typing import Dict, Literal, Optional, Union

 import torch
-from torch import autocast
+from deprecated import deprecated

-from invokeai.app.services.config.config_default import PRECISION, get_config
+from invokeai.app.services.config.config_default import get_config

+# legacy APIs
+TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
 CPU_DEVICE = torch.device("cpu")
 CUDA_DEVICE = torch.device("cuda")
 MPS_DEVICE = torch.device("mps")

-def choose_torch_device() -> torch.device:
-    """Convenience routine for guessing which GPU device to run model on"""
-    config = get_config()
-    if config.device == "auto":
-        if torch.cuda.is_available():
-            return torch.device("cuda")
-        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-            return torch.device("mps")
-        else:
-            return CPU_DEVICE
-    else:
-        return torch.device(config.device)
-
-def get_torch_device_name() -> str:
-    device = choose_torch_device()
-    return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
-
-def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]:
-    """Return an appropriate precision for the given torch device."""
-    app_config = get_config()
-    if device.type == "cuda":
-        device_name = torch.cuda.get_device_name(device)
-        if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
-            # These GPUs have limited support for float16
-            return "float32"
-        elif app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for CUDA devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    elif device.type == "mps":
-        if app_config.precision == "auto" or app_config.precision == "autocast":
-            # Default to float16 for MPS devices
-            return "float16"
-        else:
-            # Use the user-defined precision
-            return app_config.precision
-    # CPU / safe fallback
-    return "float32"
-
-def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype:
-    device = device or choose_torch_device()
-    precision = choose_precision(device)
-    if precision == "float16":
-        return torch.float16
-    if precision == "bfloat16":
-        return torch.bfloat16
-    else:
-        # "auto", "autocast", "float32"
-        return torch.float32
-
-def choose_autocast(precision: PRECISION):
-    """Returns an autocast context or nullcontext for the given precision string"""
-    # float16 currently requires autocast to avoid errors like:
-    # 'expected scalar type Half but found Float'
-    if precision == "autocast" or precision == "float16":
-        return autocast
-    return nullcontext
-
-def normalize_device(device: Union[str, torch.device]) -> torch.device:
-    """Ensure device has a device index defined, if appropriate."""
-    device = torch.device(device)
-    if device.index is None:
-        # cuda might be the only torch backend that currently uses the device index?
-        # I don't see anything like `current_device` for cpu or mps.
-        if device.type == "cuda":
-            device = torch.device(device.type, torch.cuda.current_device())
-    return device
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def choose_precision(device: torch.device) -> TorchPrecisionNames:
+    """Return the string representation of the recommended torch device."""
+    torch_dtype = TorchDevice.choose_torch_dtype(device)
+    return PRECISION_TO_NAME[torch_dtype]
+
+@deprecated("Use TorchDevice.choose_torch_device() instead.")  # type: ignore
+def choose_torch_device() -> torch.device:
+    """Return the torch.device to use for accelerated inference."""
+    return TorchDevice.choose_torch_device()
+
+@deprecated("Use TorchDevice.choose_torch_dtype() instead.")  # type: ignore
+def torch_dtype(device: torch.device) -> torch.dtype:
+    """Return the torch precision for the recommended torch device."""
+    return TorchDevice.choose_torch_dtype(device)
+
+NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = {
+    "float32": torch.float32,
+    "float16": torch.float16,
+    "bfloat16": torch.bfloat16,
+}
+PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()}
+
+class TorchDevice:
+    """Abstraction layer for torch devices."""
+
+    @classmethod
+    def choose_torch_device(cls) -> torch.device:
+        """Return the torch.device to use for accelerated inference."""
+        app_config = get_config()
+        if app_config.device != "auto":
+            device = torch.device(app_config.device)
+        elif torch.cuda.is_available():
+            device = CUDA_DEVICE
+        elif torch.backends.mps.is_available():
+            device = MPS_DEVICE
+        else:
+            device = CPU_DEVICE
+        return cls.normalize(device)
+
+    @classmethod
+    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
+        """Return the precision to use for accelerated inference."""
+        device = device or cls.choose_torch_device()
+        config = get_config()
+        if device.type == "cuda" and torch.cuda.is_available():
+            device_name = torch.cuda.get_device_name(device)
+            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
+                # These GPUs have limited support for float16
+                return cls._to_dtype("float32")
+            elif config.precision == "auto":
+                # Default to float16 for CUDA devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+        elif device.type == "mps" and torch.backends.mps.is_available():
+            if config.precision == "auto":
+                # Default to float16 for MPS devices
+                return cls._to_dtype("float16")
+            else:
+                # Use the user-defined precision
+                return cls._to_dtype(config.precision)
+        # CPU / safe fallback
+        return cls._to_dtype("float32")
+
+    @classmethod
+    def get_torch_device_name(cls) -> str:
+        """Return the device name for the current torch device."""
+        device = cls.choose_torch_device()
+        return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper()
+
+    @classmethod
+    def normalize(cls, device: Union[str, torch.device]) -> torch.device:
+        """Add the device index to CUDA devices."""
+        device = torch.device(device)
+        if device.index is None and device.type == "cuda" and torch.cuda.is_available():
+            device = torch.device(device.type, torch.cuda.current_device())
+        return device
+
+    @classmethod
+    def empty_cache(cls) -> None:
+        """Clear the GPU device cache."""
+        if torch.backends.mps.is_available():
+            torch.mps.empty_cache()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @classmethod
+    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
+        return NAME_TO_PRECISION[precision_name]
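Taken together, the new class gives call sites a single entry point for device, dtype, and cache handling. A short usage sketch, assuming a normal InvokeAI runtime where get_config() resolves:

    import torch

    from invokeai.backend.util.devices import TorchDevice

    device = TorchDevice.choose_torch_device()      # e.g. cuda:0, mps, or cpu
    dtype = TorchDevice.choose_torch_dtype(device)  # torch.float16 on most GPUs, torch.float32 on CPU

    # Allocate work on the chosen device at the chosen precision.
    latents = torch.zeros((1, 4, 64, 64), device=device, dtype=dtype)

    TorchDevice.empty_cache()  # no-op on CPU; clears the CUDA and/or MPS cache otherwise
    print(TorchDevice.get_torch_device_name())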

View File

@@ -770,6 +770,8 @@
     "float": "Float",
     "fullyContainNodes": "Fully Contain Nodes to Select",
     "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected",
+    "showEdgeLabels": "Show Edge Labels",
+    "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes",
     "hideLegendNodes": "Hide Field Type Legend",
     "hideMinimapnodes": "Hide MiniMap",
     "inputMayOnlyHaveOneConnection": "Input may only have one connection",

View File

@@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook';

 export const useGlobalHotkeys = () => {
   const dispatch = useAppDispatch();
-  const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled;
+  const isModelManagerEnabled = useFeatureStatus('modelManager');
   const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack();

   useHotkeys(

View File

@@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props
   const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
   const boardName = useBoardName(board_id);
-  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
+  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');

   const [bulkDownload] = useBulkDownloadImagesMutation();

View File

@@ -54,7 +54,7 @@ const CurrentImageButtons = () => {
   const selection = useAppSelector((s) => s.gallery.selection);
   const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons);
-  const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled;
+  const isUpscalingEnabled = useFeatureStatus('upscaling');
   const isQueueMutationInProgress = useIsQueueMutationInProgress();
   const toaster = useAppToaster();
   const { t } = useTranslation();

View File

@@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => {
   const selection = useAppSelector((s) => s.gallery.selection);
   const customStarUi = useStore($customStarUI);
-  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled;
+  const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');

   const [starImages] = useStarImagesMutation();
   const [unstarImages] = useUnstarImagesMutation();

View File

@@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
   const dispatch = useAppDispatch();
   const { t } = useTranslation();
   const toaster = useAppToaster();
-  const isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled;
+  const isCanvasEnabled = useFeatureStatus('unifiedCanvas');
   const customStarUi = useStore($customStarUI);
   const { downloadImage } = useDownloadImage();

View File

@@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
     [imageDTO?.image_name]
   );
   const isSelected = useAppSelector(selectIsSelected);
-  const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled;
+  const isMultiSelectEnabled = useFeatureStatus('multiselect');

   const handleClick = useCallback(
     (e: MouseEvent<HTMLDivElement>) => {

View File

@@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength';
 import ParamHrfToggle from './ParamHrfToggle';

 export const HrfSettings = memo(() => {
-  const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
+  const isHRFFeatureEnabled = useFeatureStatus('hrf');
   const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled);

   if (!isHRFFeatureEnabled) {

View File

@@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels';

 export const useStarterModelsToast = () => {
   const { t } = useTranslation();
-  const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled;
+  const isEnabled = useFeatureStatus('starterModels');
   const [didToast, setDidToast] = useState(false);
   const [mainModels, { data }] = useMainModels();
   const toast = useToast();

View File

@@ -1,8 +1,9 @@
+import { Flex, Text } from '@invoke-ai/ui-library';
 import { useAppSelector } from 'app/store/storeHooks';
 import type { CSSProperties } from 'react';
 import { memo, useMemo } from 'react';
 import type { EdgeProps } from 'reactflow';
-import { BaseEdge, getBezierPath } from 'reactflow';
+import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';

 import { makeEdgeSelector } from './util/makeEdgeSelector';

@@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({
     [source, sourceHandleId, target, targetHandleId, selected]
   );

-  const { isSelected, shouldAnimate, stroke } = useAppSelector(selector);
+  const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);
+  const shouldShowEdgeLabels = useAppSelector((s) => s.nodes.shouldShowEdgeLabels);

-  const [edgePath] = getBezierPath({
+  const [edgePath, labelX, labelY] = getBezierPath({
     sourceX,
     sourceY,
     sourcePosition,

@@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({
     [isSelected, shouldAnimate, stroke]
   );

-  return <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />;
+  return (
+    <>
+      <BaseEdge path={edgePath} markerEnd={markerEnd} style={edgeStyles} />
+      {label && shouldShowEdgeLabels && (
+        <EdgeLabelRenderer>
+          <Flex
+            className="nodrag nopan"
+            pointerEvents="all"
+            position="absolute"
+            transform={`translate(-50%, -50%) translate(${labelX}px,${labelY}px)`}
+            bg="base.800"
+            borderRadius="base"
+            borderWidth={1}
+            borderColor={isSelected ? 'undefined' : 'transparent'}
+            opacity={isSelected ? 1 : 0.5}
+            py={1}
+            px={3}
+            shadow="md"
+          >
+            <Text size="sm" fontWeight="semibold" color={isSelected ? 'base.100' : 'base.300'}>
+              {label}
+            </Text>
+          </Flex>
+        </EdgeLabelRenderer>
+      )}
+    </>
+  );
 };

 export default memo(InvocationDefaultEdge);

View File

@@ -1,7 +1,7 @@
 import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
 import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
 import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
-import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
+import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
 import { isInvocationNode } from 'features/nodes/types/invocation';

 import { getFieldColor } from './getEdgeColor';

@@ -10,6 +10,7 @@ const defaultReturnValue = {
   isSelected: false,
   shouldAnimate: false,
   stroke: colorTokenToCssVar('base.500'),
+  label: '',
 };

 export const makeEdgeSelector = (

@@ -19,14 +20,16 @@ export const makeEdgeSelector = (
   targetHandleId: string | null | undefined,
   selected?: boolean
 ) =>
-  createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => {
+  createMemoizedSelector(
+    selectNodesSlice,
+    (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => {
     const sourceNode = nodes.nodes.find((node) => node.id === source);
     const targetNode = nodes.nodes.find((node) => node.id === target);

     const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
     const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected);

-    if (!sourceNode || !sourceHandleId) {
+    if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
       return defaultReturnValue;
     }

@@ -35,9 +38,16 @@ export const makeEdgeSelector = (
     const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');

+    const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
+    const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
+
+    const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
+
     return {
       isSelected,
       shouldAnimate: nodes.shouldAnimateEdges && isSelected,
       stroke,
+      label,
     };
-  });
+    }
+  );

View File

@@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' };

 const InvocationNodeFooter = ({ nodeId }: Props) => {
   const hasImageOutput = useHasImageOutput(nodeId);
-  const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+  const isCacheEnabled = useFeatureStatus('invocationCache');

   return (
     <Flex
       className={DRAG_HANDLE_CLASSNAME}

View File

@@ -24,6 +24,7 @@ import {
   selectNodesSlice,
   shouldAnimateEdgesChanged,
   shouldColorEdgesChanged,
+  shouldShowEdgeLabelsChanged,
   shouldSnapToGridChanged,
   shouldValidateGraphChanged,
 } from 'features/nodes/store/nodesSlice';

@@ -35,12 +36,20 @@ import { SelectionMode } from 'reactflow';
 const formLabelProps: FormLabelProps = { flexGrow: 1 };

 const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
-  const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes;
+  const {
+    shouldAnimateEdges,
+    shouldValidateGraph,
+    shouldSnapToGrid,
+    shouldColorEdges,
+    shouldShowEdgeLabels,
+    selectionMode,
+  } = nodes;
   return {
     shouldAnimateEdges,
     shouldValidateGraph,
     shouldSnapToGrid,
     shouldColorEdges,
+    shouldShowEdgeLabels,
     selectionModeIsChecked: selectionMode === SelectionMode.Full,
   };
 });

@@ -52,8 +61,14 @@ type Props = {
 const WorkflowEditorSettings = ({ children }: Props) => {
   const { isOpen, onOpen, onClose } = useDisclosure();
   const dispatch = useAppDispatch();
-  const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } =
-    useAppSelector(selector);
+  const {
+    shouldAnimateEdges,
+    shouldValidateGraph,
+    shouldSnapToGrid,
+    shouldColorEdges,
+    shouldShowEdgeLabels,
+    selectionModeIsChecked,
+  } = useAppSelector(selector);

   const handleChangeShouldValidate = useCallback(
     (e: ChangeEvent<HTMLInputElement>) => {

@@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => {
     [dispatch]
   );

+  const handleChangeShouldShowEdgeLabels = useCallback(
+    (e: ChangeEvent<HTMLInputElement>) => {
+      dispatch(shouldShowEdgeLabelsChanged(e.target.checked));
+    },
+    [dispatch]
+  );
+
   const { t } = useTranslation();

   return (

@@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => {
               <FormHelperText>{t('nodes.fullyContainNodesHelp')}</FormHelperText>
             </FormControl>
             <Divider />
+            <FormControl>
+              <Flex w="full">
+                <FormLabel>{t('nodes.showEdgeLabels')}</FormLabel>
+                <Switch isChecked={shouldShowEdgeLabels} onChange={handleChangeShouldShowEdgeLabels} />
+              </Flex>
+              <FormHelperText>{t('nodes.showEdgeLabelsHelp')}</FormHelperText>
+            </FormControl>
+            <Divider />
             <Heading size="sm" pt={4}>
               {t('common.advanced')}
             </Heading>

View File

@ -1,5 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants'; import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { selectNodeTemplate } from 'features/nodes/store/selectors';
@ -10,7 +10,7 @@ import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => { export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createSelector(selectNodesSlice, (nodes) => { createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId); const template = selectNodeTemplate(nodes, nodeId);
if (!template) { if (!template) {
return EMPTY_ARRAY; return EMPTY_ARRAY;

View File

@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput';
export const useWithFooter = (nodeId: string) => { export const useWithFooter = (nodeId: string) => {
const hasImageOutput = useHasImageOutput(nodeId); const hasImageOutput = useHasImageOutput(nodeId);
const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled; const isCacheEnabled = useFeatureStatus('invocationCache');
const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]); const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]);
return withFooter; return withFooter;
}; };

View File

@ -103,6 +103,7 @@ const initialNodesState: NodesState = {
shouldAnimateEdges: true, shouldAnimateEdges: true,
shouldSnapToGrid: false, shouldSnapToGrid: false,
shouldColorEdges: true, shouldColorEdges: true,
shouldShowEdgeLabels: false,
isAddNodePopoverOpen: false, isAddNodePopoverOpen: false,
nodeOpacity: 1, nodeOpacity: 1,
selectedNodes: [], selectedNodes: [],
@ -549,6 +550,9 @@ export const nodesSlice = createSlice({
shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => { shouldAnimateEdgesChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAnimateEdges = action.payload; state.shouldAnimateEdges = action.payload;
}, },
shouldShowEdgeLabelsChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowEdgeLabels = action.payload;
},
shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => { shouldSnapToGridChanged: (state, action: PayloadAction<boolean>) => {
state.shouldSnapToGrid = action.payload; state.shouldSnapToGrid = action.payload;
}, },
@ -831,6 +835,7 @@ export const {
viewportChanged, viewportChanged,
edgeAdded, edgeAdded,
nodeTemplatesBuilt, nodeTemplatesBuilt,
shouldShowEdgeLabelsChanged,
} = nodesSlice.actions; } = nodesSlice.actions;
// This is used for tracking `state.workflow.isTouched` // This is used for tracking `state.workflow.isTouched`
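A small usage sketch (illustrative only, not part of this commit) of how the new action round-trips through the reducer, using Redux Toolkit's getInitialState helper; it doubles as the shape of a unit test:

import { nodesSlice, shouldShowEdgeLabelsChanged } from 'features/nodes/store/nodesSlice';

// shouldShowEdgeLabels defaults to false in initialNodesState (see the hunk above),
// so dispatching the action flips the flag without touching any other state.
const before = nodesSlice.getInitialState();
const after = nodesSlice.reducer(before, shouldShowEdgeLabelsChanged(true));
console.assert(before.shouldShowEdgeLabels === false);
console.assert(after.shouldShowEdgeLabels === true);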

View File

@ -32,6 +32,7 @@ export type NodesState = {
isAddNodePopoverOpen: boolean; isAddNodePopoverOpen: boolean;
addNewNodePosition: XYPosition | null; addNewNodePosition: XYPosition | null;
selectionMode: SelectionMode; selectionMode: SelectionMode;
shouldShowEdgeLabels: boolean;
}; };
export type WorkflowMode = 'edit' | 'view'; export type WorkflowMode = 'edit' | 'view';

View File

@ -1,24 +1,18 @@
import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker'; import IAIColorPicker from 'common/components/IAIColorPicker';
import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful'; import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue);
const ParamInfillColorOptions = () => { const ParamInfillColorOptions = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo( const infillColor = useAppSelector(selectInfillColor);
() =>
createSelector(selectGenerationSlice, (generation) => ({
infillColor: generation.infillColorValue,
})),
[]
);
const { infillColor } = useAppSelector(selector);
const infillMethod = useAppSelector((s) => s.generation.infillMethod); const infillMethod = useAppSelector((s) => s.generation.infillMethod);
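The same hoisting pattern in generic form (the hook name is illustrative and not added by this commit): a selector that takes no component-level arguments can be built once with createMemoizedSelector at module scope, so every consumer shares a single memoized instance instead of recreating one inside useMemo on each mount. The next file applies the same idea by dropping its inline selector in favour of plain per-field useAppSelector reads:

import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';

// Module-scope selector: created once, memoized across all subscribers.
const selectInfillColorValue = createMemoizedSelector(
  selectGenerationSlice,
  (generation) => generation.infillColorValue
);

export const useInfillColorValue = () => useAppSelector(selectInfillColorValue);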

View File

@ -1,35 +1,23 @@
import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIColorPicker from 'common/components/IAIColorPicker'; import IAIColorPicker from 'common/components/IAIColorPicker';
import { import {
selectGenerationSlice,
setInfillMosaicMaxColor, setInfillMosaicMaxColor,
setInfillMosaicMinColor, setInfillMosaicMinColor,
setInfillMosaicTileHeight, setInfillMosaicTileHeight,
setInfillMosaicTileWidth, setInfillMosaicTileWidth,
} from 'features/parameters/store/generationSlice'; } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback } from 'react';
import type { RgbaColor } from 'react-colorful'; import type { RgbaColor } from 'react-colorful';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const ParamInfillMosaicTileSize = () => { const ParamInfillMosaicTileSize = () => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const selector = useMemo( const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth);
() => const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight);
createSelector(selectGenerationSlice, (generation) => ({ const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor);
infillMosaicTileWidth: generation.infillMosaicTileWidth, const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor);
infillMosaicTileHeight: generation.infillMosaicTileHeight,
infillMosaicMinColor: generation.infillMosaicMinColor,
infillMosaicMaxColor: generation.infillMosaicMaxColor,
})),
[]
);
const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } =
useAppSelector(selector);
const infillMethod = useAppSelector((s) => s.generation.infillMethod); const infillMethod = useAppSelector((s) => s.generation.infillMethod);
const { t } = useTranslation(); const { t } = useTranslation();

View File

@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const { t } = useTranslation(); const { t } = useTranslation();
const clearQueueDisclosure = useDisclosure(); const clearQueueDisclosure = useDisclosure();
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled; const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled; const isResumeEnabled = useFeatureStatus('resumeQueue');
const { queueSize } = useGetQueueStatusQuery(undefined, { const { queueSize } = useGetQueueStatusQuery(undefined, {
selectFromResult: (res) => ({ selectFromResult: (res) => ({
queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0, queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,

View File

@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
import { QueueActionsMenuButton } from './QueueActionsMenuButton'; import { QueueActionsMenuButton } from './QueueActionsMenuButton';
const QueueControls = () => { const QueueControls = () => {
const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled; const isPrependEnabled = useFeatureStatus('prependQueue');
return ( return (
<Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column"> <Flex w="full" position="relative" borderRadius="base" gap={2} pt={2} flexDir="column">
<ButtonGroup size="lg" isAttached={false}> <ButtonGroup size="lg" isAttached={false}>

View File

@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
import QueueTabQueueControls from './QueueTabQueueControls'; import QueueTabQueueControls from './QueueTabQueueControls';
const QueueTabContent = () => { const QueueTabContent = () => {
const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled; const isInvocationCacheEnabled = useFeatureStatus('invocationCache');
return ( return (
<Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}> <Flex borderRadius="base" w="full" h="full" flexDir="column" gap={2}>

View File

@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
import ResumeProcessorButton from './ResumeProcessorButton'; import ResumeProcessorButton from './ResumeProcessorButton';
const QueueTabQueueControls = () => { const QueueTabQueueControls = () => {
const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled; const isPauseEnabled = useFeatureStatus('pauseQueue');
const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled; const isResumeEnabled = useFeatureStatus('resumeQueue');
return ( return (
<Flex layerStyle="first" borderRadius="base" p={2} gap={2}> <Flex layerStyle="first" borderRadius="base" p={2} gap={2}>
{isPauseEnabled || isResumeEnabled ? ( {isPauseEnabled || isResumeEnabled ? (

View File

@ -13,7 +13,7 @@ export const useQueueFront = () => {
const [_, { isLoading }] = useEnqueueBatchMutation({ const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch', fixedCacheKey: 'enqueueBatch',
}); });
const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled; const prependEnabled = useFeatureStatus('prependQueue');
const isDisabled = useMemo(() => { const isDisabled = useMemo(() => {
return !isReady || !prependEnabled; return !isReady || !prependEnabled;

View File

@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
export const ControlSettingsAccordion: React.FC = memo(() => { export const ControlSettingsAccordion: React.FC = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const { controlAdapterIds, badges } = useAppSelector(selector); const { controlAdapterIds, badges } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled; const isControlNetEnabled = useFeatureStatus('controlNet');
const { isOpen, onToggle } = useStandaloneAccordionToggle({ const { isOpen, onToggle } = useStandaloneAccordionToggle({
id: 'control-settings', id: 'control-settings',
defaultIsOpen: true, defaultIsOpen: true,
@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter'); const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter'); const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
if (isControlNetDisabled) { if (!isControlNetEnabled) {
return null; return null;
} }

View File

@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const language = useAppSelector((s) => s.system.language); const language = useAppSelector((s) => s.system.language);
const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled; const isLocalizationEnabled = useFeatureStatus('localization');
const value = useMemo(() => options.find((o) => o.value === language), [language]); const value = useMemo(() => options.find((o) => o.value === language), [language]);

View File

@ -23,9 +23,9 @@ const SettingsMenu = () => {
const { isOpen, onOpen, onClose } = useDisclosure(); const { isOpen, onOpen, onClose } = useDisclosure();
useGlobalMenuClose(onClose); useGlobalMenuClose(onClose);
const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled; const isBugLinkEnabled = useFeatureStatus('bugLink');
const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled; const isDiscordLinkEnabled = useFeatureStatus('discordLink');
const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled; const isGithubLinkEnabled = useFeatureStatus('githubLink');
return ( return (
<Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}> <Menu isOpen={isOpen} onOpen={onOpen} onClose={onClose}>

View File

@ -1,32 +1,24 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import type { AppFeature, SDFeature } from 'app/types/invokeai'; import type { AppFeature, SDFeature } from 'app/types/invokeai';
import { selectConfigSlice } from 'features/system/store/configSlice';
import type { InvokeTabName } from 'features/ui/store/tabMap'; import type { InvokeTabName } from 'features/ui/store/tabMap';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => { export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
const disabledTabs = useAppSelector((s) => s.config.disabledTabs); const selectIsFeatureEnabled = useMemo(
const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);
const disabledSDFeatures = useAppSelector((s) => s.config.disabledSDFeatures);
const isFeatureDisabled = useMemo(
() => () =>
disabledFeatures.includes(feature as AppFeature) || createSelector(selectConfigSlice, (config) => {
disabledSDFeatures.includes(feature as SDFeature) || return !(
disabledTabs.includes(feature as InvokeTabName), config.disabledFeatures.includes(feature as AppFeature) ||
[disabledFeatures, disabledSDFeatures, disabledTabs, feature] config.disabledSDFeatures.includes(feature as SDFeature) ||
config.disabledTabs.includes(feature as InvokeTabName)
);
}),
[feature]
); );
const isFeatureEnabled = useMemo( const isFeatureEnabled = useAppSelector(selectIsFeatureEnabled);
() =>
!(
disabledFeatures.includes(feature as AppFeature) ||
disabledSDFeatures.includes(feature as SDFeature) ||
disabledTabs.includes(feature as InvokeTabName)
),
[disabledFeatures, disabledSDFeatures, disabledTabs, feature]
);
return { isFeatureDisabled, isFeatureEnabled }; return isFeatureEnabled;
}; };
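A minimal call-site sketch (hook names and the relative import are illustrative, not part of this commit) of the new return shape: useFeatureStatus now yields the enabled flag directly, so former `.isFeatureEnabled` reads drop the property access and former `.isFeatureDisabled` reads become a negation, as in the ControlSettingsAccordion hunk above:

import { useFeatureStatus } from './useFeatureStatus';

// Previously: useFeatureStatus('invocationCache').isFeatureEnabled
export const useIsInvocationCacheEnabled = (): boolean => useFeatureStatus('invocationCache');

// Previously: useFeatureStatus('controlNet').isFeatureDisabled
export const useIsControlNetDisabled = (): boolean => !useFeatureStatus('controlNet');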

View File

@ -0,0 +1,132 @@
"""
Test abstract device class.
"""
from unittest.mock import patch
import pytest
import torch
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]
@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
config = get_config()
config.device = device_name
torch_device = TorchDevice.choose_torch_device()
assert torch_device == torch.device(device_name)
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=False),
patch("torch.backends.mps.is_available", return_value=False),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value="RTX4070"),
patch("torch.backends.mps.is_available", return_value=False),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=False),
patch("torch.backends.mps.is_available", return_value=True),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
with (
patch("torch.cuda.get_device_name", return_value="RTX4070"),
patch("torch.cuda.is_available", return_value=True),
patch("torch.backends.mps.is_available", return_value=False),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
config.precision = "float32"
torch_dtype = TorchDevice.choose_torch_dtype()
assert torch_dtype == torch.float32
def test_normalize():
assert (
TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
)
assert (
TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
)
assert (
TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
)
assert TorchDevice.normalize("mps") == torch.device("mps")
assert TorchDevice.normalize("cpu") == torch.device("cpu")
@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
config = get_config()
config.device = device_name
with pytest.deprecated_call():
torch_device = choose_torch_device()
assert torch_device == torch.device(device_name)
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
with (
patch("torch.cuda.is_available", return_value=False),
patch("torch.backends.mps.is_available", return_value=False),
patch("torch.cuda.get_device_name", return_value="RTX9090"),
):
device_name, dtype = device_dtype_pair
config = get_config()
config.device = device_name
with pytest.deprecated_call():
torch_device = choose_torch_device()
returned_dtype = torch_dtype(torch_device)
assert returned_dtype == dtype
def test_legacy_precision_name():
config = get_config()
config.precision = "auto"
with (
pytest.deprecated_call(),
patch("torch.cuda.is_available", return_value=True),
patch("torch.backends.mps.is_available", return_value=True),
patch("torch.cuda.get_device_name", return_value="RTX9090"),
):
assert "float16" == choose_precision(torch.device("cuda"))
assert "float16" == choose_precision(torch.device("mps"))
assert "float32" == choose_precision(torch.device("cpu"))