diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 87eaefc020..ceaeb95147 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.backend.util.devices import get_torch_device_name +from invokeai.backend.util.devices import TorchDevice from ..backend.util.logging import InvokeAILogger from .api.dependencies import ApiDependencies @@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config) mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("text/css", ".css") -torch_device_name = get_torch_device_name() +torch_device_name = TorchDevice.get_torch_device_name() logger.info(f"Using torch device: {torch_device_name}") diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 481a2d2e4b..158f11a58e 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( ConditioningFieldData, SDXLConditioningInfo, ) -from invokeai.backend.util.devices import torch_dtype +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .model import CLIPField @@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation): tokenizer=tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, + dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, ) @@ -193,7 +193,7 @@ class SDXLPromptInvocationBase: tokenizer=tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, + dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=get_pooled, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ce63d568c6..a8ead96f3a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( image_resized_to_grid_as_tensor, ) from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.util.devices import choose_precision, choose_torch_device +from ...backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .controlnet_image_processors import ControlField from .model import ModelIdentifierField, UNetField, VAEField -if choose_torch_device() == torch.device("mps"): - from torch import mps - -DEFAULT_PRECISION = choose_precision(choose_torch_device()) +DEFAULT_PRECISION = TorchDevice.choose_torch_dtype() @invocation_output("scheduler_output") @@ -959,9 +956,7 @@ class DenoiseLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.to("cpu") - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() name = 
context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) @@ -1028,9 +1023,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): vae.disable_tiling() # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() with torch.inference_mode(): # copied from diffusers pipeline @@ -1042,9 +1035,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): image = VaeImageProcessor.numpy_to_pil(np_image)[0] - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() image_dto = context.images.save(image=image) @@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.tensors.load(self.latents.latents_name) - - # TODO: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() resized_latents = torch.nn.functional.interpolate( latents.to(device), @@ -1096,9 +1085,8 @@ class ResizeLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 resized_latents = resized_latents.to("cpu") - torch.cuda.empty_cache() - if device == torch.device("mps"): - mps.empty_cache() + + TorchDevice.empty_cache() name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.tensors.load(self.latents.latents_name) - # TODO: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() # resizing resized_latents = torch.nn.functional.interpolate( @@ -1138,9 +1125,7 @@ class ScaleLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 resized_latents = resized_latents.to("cpu") - torch.cuda.empty_cache() - if device == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1272,8 +1257,7 @@ class BlendLatentsInvocation(BaseInvocation): if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") - # TODO: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() def slerp( t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? 
@@ -1326,9 +1310,8 @@ class BlendLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") - torch.cuda.empty_cache() - if device == torch.device("mps"): - mps.empty_cache() + + TorchDevice.empty_cache() name = context.tensors.save(tensor=blended_latents) return LatentsOutput.build(latents_name=name, latents=blended_latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 6e5612b8b0..931e639106 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX -from ...backend.util.devices import choose_torch_device, torch_dtype +from ...backend.util.devices import TorchDevice from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -46,7 +46,7 @@ def get_noise( height // downsampling_factor, width // downsampling_factor, ], - dtype=torch_dtype(device), + dtype=TorchDevice.choose_torch_dtype(device=device), device=noise_device_type, generator=generator, ).to("cpu") @@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation): @field_validator("seed", mode="before") def modulo_seed(cls, v): - """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" + """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, - device=choose_torch_device(), + device=TorchDevice.choose_torch_device(), seed=self.seed, use_cpu=self.use_cpu, ) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index e09618960e..b8acfcb7bf 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -3,7 +3,6 @@ from typing import Literal import cv2 import numpy as np -import torch from PIL import Image from pydantic import ConfigDict @@ -12,7 +11,7 @@ from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithBoard, WithMetadata @@ -33,9 +32,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = { "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", } -if choose_torch_device() == torch.device("mps"): - from torch import mps - @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2") class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -115,9 +111,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() image_dto = context.images.save(image=pil_image) diff --git a/invokeai/app/services/config/config_default.py 
b/invokeai/app/services/config/config_default.py index 4b5a2004be..496988e853 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0 DEFAULT_VRAM_CACHE = 0.25 DEFAULT_CONVERT_CACHE = 20.0 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] -PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"] +PRECISION = Literal["auto", "float16", "bfloat16", "float32"] ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] -CONFIG_SCHEMA_VERSION = "4.0.0" +CONFIG_SCHEMA_VERSION = "4.0.1" def get_default_ram_cache_size() -> float: @@ -106,7 +106,7 @@ class InvokeAIAppConfig(BaseSettings): lazy_offload: Keep models in VRAM until their space is needed. log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` - precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast` + precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. attention_type: Attention type.
Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp` attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` @@ -377,6 +377,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used if k == "max_vram_cache_size" and "vram" not in category_dict: parsed_config_dict["vram"] = v + # autocast was removed in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" if k == "conf_path": parsed_config_dict["legacy_models_yaml_path"] = v if k == "legacy_conf_dir": @@ -399,6 +402,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: return config +def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: + """Migrate v4.0.0 config dictionary to a current config object. + + Args: + config_dict: A dictionary of settings from a v4.0.0 config file. + + Returns: + An instance of `InvokeAIAppConfig` with the migrated settings. + """ + parsed_config_dict: dict[str, Any] = {} + for k, v in config_dict.items(): + # autocast was removed from precision in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + else: + parsed_config_dict[k] = v + if k == "schema_version": + parsed_config_dict[k] = CONFIG_SCHEMA_VERSION + config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict) + return config + + def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """Load and migrate a config file to the latest version. @@ -425,17 +450,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e migrated_config.write_file(config_path) return migrated_config - else: - # Attempt to load as a v4 config file - try: - # Meta is not included in the model fields, so we need to validate it separately - config = InvokeAIAppConfig.model_validate(loaded_config_dict) - assert ( - config.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" - return config - except Exception as e: - raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e + + if loaded_config_dict["schema_version"] == "4.0.0": + loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict) + loaded_config_dict.write_file(config_path) + + # Attempt to load as a v4 config file + try: + # Meta is not included in the model fields, so we need to validate it separately + config = InvokeAIAppConfig.model_validate(loaded_config_dict) + assert ( + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" + return config + except Exception as e: + raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e @lru_cache(maxsize=1) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index f1fbcdb7ba..f0e25648bf 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp from typing import Any, Dict, List, Optional, Union +import torch import yaml from huggingface_hub import HfFolder from pydantic.networks import AnyHttpUrl @@ -42,7 +43,7 @@ from 
invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import InvokeAILogger -from invokeai.backend.util.devices import choose_precision, choose_torch_device +from invokeai.backend.util.devices import TorchDevice from .model_install_base import ( MODEL_SOURCE_TO_TYPE_MAP, @@ -643,11 +644,10 @@ class ModelInstallService(ModelInstallServiceBase): self._next_job_id += 1 return id - @staticmethod - def _guess_variant() -> Optional[ModelRepoVariant]: + def _guess_variant(self) -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" - precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else None + precision = TorchDevice.choose_torch_dtype() + return ModelRepoVariant.FP16 if precision == torch.float16 else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index de6e5f09d8..1a2b9a3402 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,12 +1,14 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" +from typing import Optional + import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase): model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, - execution_device: torch.device = choose_torch_device(), + execution_device: Optional[torch.device] = None, ) -> Self: """ Construct the model manager service instance. 
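For orientation (illustrative only, not part of the patch): the choose_torch_device()/choose_precision() helpers being removed in the hunks above are replaced throughout this diff by the classmethod-based TorchDevice API added in invokeai/backend/util/devices.py further down. A minimal usage sketch, assuming an InvokeAI install with this patch applied:

    import torch
    from invokeai.backend.util.devices import TorchDevice

    device = TorchDevice.choose_torch_device()      # honors config.device, else cuda -> mps -> cpu, cuda index normalized
    dtype = TorchDevice.choose_torch_dtype(device)  # honors config.precision; "auto" -> float16 on cuda/mps, float32 on cpu
    print(TorchDevice.get_torch_device_name())      # CUDA card name, or the device type upper-cased ("MPS", "CPU")
    TorchDevice.empty_cache()                       # clears the CUDA and/or MPS cache when available
    assert isinstance(device, torch.device) and isinstance(dtype, torch.dtype)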
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase): max_vram_cache_size=app_config.vram, lazy_offloading=app_config.lazy_offload, logger=logger, - execution_device=execution_device, + execution_device=execution_device or TorchDevice.choose_torch_device(), ) convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache) loader = ModelLoadService( diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index 560d977b55..2d88c45485 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -12,7 +12,7 @@ from invokeai.app.services.config.config_default import get_config from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger config = get_config() @@ -47,7 +47,7 @@ class DepthAnythingDetector: self.context = context self.model: Optional[DPT_DINOv2] = None self.model_size: Union[Literal["large", "base", "small"], None] = None - self.device = choose_torch_device() + self.device = TorchDevice.choose_torch_device() def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2: depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size]) @@ -68,7 +68,7 @@ class DepthAnythingDetector: self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu")) self.model.eval() - self.model.to(choose_torch_device()) + self.model.to(self.device) return self.model def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image: @@ -81,7 +81,7 @@ class DepthAnythingDetector: image_height, image_width = np_image.shape[:2] np_image = transform({"image": np_image})["image"] - tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device()) + tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device) with torch.no_grad(): depth = self.model(tensor_image) diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 3628b0abd5..0f66af2c77 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -4,11 +4,10 @@ import numpy as np import onnxruntime as ort -import torch from invokeai.app.services.config.config_default import get_config from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from .onnxdet import inference_detector from .onnxpose import inference_pose @@ -23,9 +22,9 @@ config = get_config() class Wholebody: def __init__(self, context: InvocationContext): - device = choose_torch_device() + device = TorchDevice.choose_torch_device() - providers = ["CUDAExecutionProvider"] if device == torch.device("cuda") else ["CPUExecutionProvider"] + providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] onnx_det = context.models.download_and_cache_ckpt(DWPOSE_MODELS["yolox_l.onnx"]) 
onnx_pose = context.models.download_and_cache_ckpt(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]) diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py index 7c4d90f5bd..c5fe3fa598 100644 --- a/invokeai/backend/image_util/realesrgan/realesrgan.py +++ b/invokeai/backend/image_util/realesrgan/realesrgan.py @@ -11,7 +11,7 @@ from tqdm import tqdm from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.model_manager.config import AnyModel -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice """ Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py @@ -65,7 +65,7 @@ class RealESRGAN: self.pre_pad = pre_pad self.mod_scale: Optional[int] = None self.half = half - self.device = choose_torch_device() + self.device = TorchDevice.choose_torch_device() # prefer to use params_ema if "params_ema" in loadnet: diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 7bceae8da7..60dcd93fcc 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.silence_warnings import SilenceWarnings CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" @@ -51,7 +51,7 @@ class SafetyChecker: cls._load_safety_checker() if cls.safety_checker is None or cls.feature_extractor is None: return False - device = choose_torch_device() + device = TorchDevice.choose_torch_device() features = cls.feature_extractor([image], return_tensors="pt") features.to(device) cls.safety_checker.to(device) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 451770c0cb..459808b455 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init -from invokeai.backend.util.devices import choose_torch_device, torch_dtype +from invokeai.backend.util.devices import TorchDevice # TO DO: The loader is not thread safe! 
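Illustrative sketch (not part of the patch): when called with no argument, as the ModelLoader hunk below now does, TorchDevice.choose_torch_dtype() first resolves the device via choose_torch_device() and then applies the precision policy defined in the new devices.py further down in this diff. Assuming precision is left at "auto" in the config:

    import torch
    from invokeai.backend.util.devices import TorchDevice

    # CPU (or any device type other than cuda/mps) falls back to float32
    assert TorchDevice.choose_torch_dtype(torch.device("cpu")) == torch.float32
    # CUDA and MPS default to float16; GTX 1650/1660 cards are pinned to float32
    # because of their limited float16 support
    dtype = TorchDevice.choose_torch_dtype()  # device resolved internally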
@@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase): self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - self._torch_dtype = torch_dtype(choose_torch_device()) + self._torch_dtype = TorchDevice.choose_torch_dtype() def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 919a7c4396..62bb766cd6 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -31,15 +31,12 @@ import torch from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase from .model_locker import ModelLocker -if choose_torch_device() == torch.device("mps"): - from torch import mps - # Maximum size of the cache, in gigs # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously DEFAULT_MAX_CACHE_SIZE = 6.0 @@ -245,9 +242,7 @@ class ModelCache(ModelCacheBase[AnyModel]): f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" ) - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: """Move model into the indicated device. @@ -417,10 +412,7 @@ class ModelCache(ModelCacheBase[AnyModel]): self.stats.cleared = models_cleared gc.collect() - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - + TorchDevice.empty_cache() self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None: diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index eb6fd45e1a..125e99be93 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.backend.util.devices import choose_torch_device, torch_dtype +from invokeai.backend.util.devices import TorchDevice from . import ( AnyModelConfig, @@ -43,6 +43,7 @@ class ModelMerger(object): Initialize a ModelMerger object with the model installer. 
""" self._installer = installer + self._dtype = TorchDevice.choose_torch_dtype() def merge_diffusion_models( self, @@ -68,7 +69,7 @@ class ModelMerger(object): warnings.simplefilter("ignore") verbosity = dlogging.get_verbosity() dlogging.set_verbosity_error() - dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device()) + dtype = torch.float16 if variant == "fp16" else self._dtype # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released. @@ -151,7 +152,7 @@ class ModelMerger(object): dump_path.mkdir(parents=True, exist_ok=True) dump_path = dump_path / merged_model_name - dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device()) + dtype = torch.float16 if variant == "fp16" else self._dtype merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant) # register model and get its unique key diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index b4d1b3381c..bd60b0b8c7 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.attention import auto_detect_slice_size -from invokeai.backend.util.devices import normalize_device +from invokeai.backend.util.devices import TorchDevice @dataclass @@ -258,7 +258,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if self.unet.device.type == "cpu" or self.unet.device.type == "mps": mem_free = psutil.virtual_memory().free elif self.unet.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device)) + mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device)) else: raise ValueError(f"unrecognized device {self.unet.device}") # input tensor of [1, 4, h/8, w/8] diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 2c9cceff2c..1e4d467cd0 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -2,7 +2,6 @@ Initialization file for invokeai.backend.util """ -from .devices import choose_precision, choose_torch_device from .logging import InvokeAILogger from .util import GIG, Chdir, directory_size @@ -11,6 +10,4 @@ __all__ = [ "directory_size", "Chdir", "InvokeAILogger", - "choose_precision", - "choose_torch_device", ] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index cb6b93eaac..e8380dc8bc 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,89 +1,110 @@ -from __future__ import annotations - -from contextlib import nullcontext -from typing import Literal, Optional, Union +from typing import Dict, Literal, Optional, Union import torch -from torch import autocast +from deprecated import deprecated -from invokeai.app.services.config.config_default import PRECISION, get_config +from invokeai.app.services.config.config_default import get_config +# legacy APIs +TorchPrecisionNames = Literal["float32", "float16", "bfloat16"] CPU_DEVICE = torch.device("cpu") CUDA_DEVICE = 
torch.device("cuda") MPS_DEVICE = torch.device("mps") +@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore +def choose_precision(device: torch.device) -> TorchPrecisionNames: + """Return the string representation of the recommended torch device.""" + torch_dtype = TorchDevice.choose_torch_dtype(device) + return PRECISION_TO_NAME[torch_dtype] + + +@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore def choose_torch_device() -> torch.device: - """Convenience routine for guessing which GPU device to run model on""" - config = get_config() - if config.device == "auto": - if torch.cuda.is_available(): - return torch.device("cuda") - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - return torch.device("mps") + """Return the torch.device to use for accelerated inference.""" + return TorchDevice.choose_torch_device() + + +@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore +def torch_dtype(device: torch.device) -> torch.dtype: + """Return the torch precision for the recommended torch device.""" + return TorchDevice.choose_torch_dtype(device) + + +NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, +} +PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()} + + +class TorchDevice: + """Abstraction layer for torch devices.""" + + @classmethod + def choose_torch_device(cls) -> torch.device: + """Return the torch.device to use for accelerated inference.""" + app_config = get_config() + if app_config.device != "auto": + device = torch.device(app_config.device) + elif torch.cuda.is_available(): + device = CUDA_DEVICE + elif torch.backends.mps.is_available(): + device = MPS_DEVICE else: - return CPU_DEVICE - else: - return torch.device(config.device) + device = CPU_DEVICE + return cls.normalize(device) + @classmethod + def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype: + """Return the precision to use for accelerated inference.""" + device = device or cls.choose_torch_device() + config = get_config() + if device.type == "cuda" and torch.cuda.is_available(): + device_name = torch.cuda.get_device_name(device) + if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name: + # These GPUs have limited support for float16 + return cls._to_dtype("float32") + elif config.precision == "auto": + # Default to float16 for CUDA devices + return cls._to_dtype("float16") + else: + # Use the user-defined precision + return cls._to_dtype(config.precision) -def get_torch_device_name() -> str: - device = choose_torch_device() - return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + elif device.type == "mps" and torch.backends.mps.is_available(): + if config.precision == "auto": + # Default to float16 for MPS devices + return cls._to_dtype("float16") + else: + # Use the user-defined precision + return cls._to_dtype(config.precision) + # CPU / safe fallback + return cls._to_dtype("float32") + @classmethod + def get_torch_device_name(cls) -> str: + """Return the device name for the current torch device.""" + device = cls.choose_torch_device() + return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() -def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]: - """Return an appropriate precision for the given torch device.""" - app_config 
= get_config() - if device.type == "cuda": - device_name = torch.cuda.get_device_name(device) - if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name: - # These GPUs have limited support for float16 - return "float32" - elif app_config.precision == "auto" or app_config.precision == "autocast": - # Default to float16 for CUDA devices - return "float16" - else: - # Use the user-defined precision - return app_config.precision - elif device.type == "mps": - if app_config.precision == "auto" or app_config.precision == "autocast": - # Default to float16 for MPS devices - return "float16" - else: - # Use the user-defined precision - return app_config.precision - # CPU / safe fallback - return "float32" - - -def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype: - device = device or choose_torch_device() - precision = choose_precision(device) - if precision == "float16": - return torch.float16 - if precision == "bfloat16": - return torch.bfloat16 - else: - # "auto", "autocast", "float32" - return torch.float32 - - -def choose_autocast(precision: PRECISION): - """Returns an autocast context or nullcontext for the given precision string""" - # float16 currently requires autocast to avoid errors like: - # 'expected scalar type Half but found Float' - if precision == "autocast" or precision == "float16": - return autocast - return nullcontext - - -def normalize_device(device: Union[str, torch.device]) -> torch.device: - """Ensure device has a device index defined, if appropriate.""" - device = torch.device(device) - if device.index is None: - # cuda might be the only torch backend that currently uses the device index? - # I don't see anything like `current_device` for cpu or mps. - if device.type == "cuda": + @classmethod + def normalize(cls, device: Union[str, torch.device]) -> torch.device: + """Add the device index to CUDA devices.""" + device = torch.device(device) + if device.index is None and device.type == "cuda" and torch.cuda.is_available(): device = torch.device(device.type, torch.cuda.current_device()) - return device + return device + + @classmethod + def empty_cache(cls) -> None: + """Clear the GPU device cache.""" + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @classmethod + def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype: + return NAME_TO_PRECISION[precision_name] diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 2f4e70005f..1adac4c5dc 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -770,6 +770,8 @@ "float": "Float", "fullyContainNodes": "Fully Contain Nodes to Select", "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected", + "showEdgeLabels": "Show Edge Labels", + "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes", "hideLegendNodes": "Hide Field Type Legend", "hideMinimapnodes": "Hide MiniMap", "inputMayOnlyHaveOneConnection": "Input may only have one connection", diff --git a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts b/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts index 372452ce99..bbb7897575 100644 --- a/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts +++ b/invokeai/frontend/web/src/common/hooks/useGlobalHotkeys.ts @@ -9,7 +9,7 @@ import { useHotkeys } from 'react-hotkeys-hook'; export const useGlobalHotkeys = () => { const 
dispatch = useAppDispatch(); - const isModelManagerEnabled = useFeatureStatus('modelManager').isFeatureEnabled; + const isModelManagerEnabled = useFeatureStatus('modelManager'); const { queueBack, isDisabled: isDisabledQueueBack, isLoading: isLoadingQueueBack } = useQueueBack(); useHotkeys( diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index ad6c37532e..a67fd5f82d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -32,7 +32,7 @@ const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd); const boardName = useBoardName(board_id); - const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled; + const isBulkDownloadEnabled = useFeatureStatus('bulkDownload'); const [bulkDownload] = useBulkDownloadImagesMutation(); diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx index 15ab74f44d..880fdbca6c 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImageButtons.tsx @@ -54,7 +54,7 @@ const CurrentImageButtons = () => { const selection = useAppSelector((s) => s.gallery.selection); const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); - const isUpscalingEnabled = useFeatureStatus('upscaling').isFeatureEnabled; + const isUpscalingEnabled = useFeatureStatus('upscaling'); const isQueueMutationInProgress = useIsQueueMutationInProgress(); const toaster = useAppToaster(); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx index 7b1fa73472..54fb19c844 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems.tsx @@ -20,7 +20,7 @@ const MultipleSelectionMenuItems = () => { const selection = useAppSelector((s) => s.gallery.selection); const customStarUi = useStore($customStarUI); - const isBulkDownloadEnabled = useFeatureStatus('bulkDownload').isFeatureEnabled; + const isBulkDownloadEnabled = useFeatureStatus('bulkDownload'); const [starImages] = useStarImagesMutation(); const [unstarImages] = useUnstarImagesMutation(); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index 7f43ab3671..aff74481ca 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -45,7 +45,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const dispatch = useAppDispatch(); const { t } = useTranslation(); const toaster = useAppToaster(); - const 
isCanvasEnabled = useFeatureStatus('unifiedCanvas').isFeatureEnabled; + const isCanvasEnabled = useFeatureStatus('unifiedCanvas'); const customStarUi = useStore($customStarUI); const { downloadImage } = useDownloadImage(); diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts index 05e7d075e5..f84a349d2a 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts @@ -18,7 +18,7 @@ export const useMultiselect = (imageDTO?: ImageDTO) => { [imageDTO?.image_name] ); const isSelected = useAppSelector(selectIsSelected); - const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled; + const isMultiSelectEnabled = useFeatureStatus('multiselect'); const handleClick = useCallback( (e: MouseEvent) => { diff --git a/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx b/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx index 2cb96a935f..eaa6d60dda 100644 --- a/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx +++ b/invokeai/frontend/web/src/features/hrf/components/HrfSettings.tsx @@ -8,7 +8,7 @@ import ParamHrfStrength from './ParamHrfStrength'; import ParamHrfToggle from './ParamHrfToggle'; export const HrfSettings = memo(() => { - const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled; + const isHRFFeatureEnabled = useFeatureStatus('hrf'); const hrfEnabled = useAppSelector((s) => s.hrf.hrfEnabled); if (!isHRFFeatureEnabled) { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx index 0ee06f19dd..6abc633ac8 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx @@ -10,7 +10,7 @@ const TOAST_ID = 'starterModels'; export const useStarterModelsToast = () => { const { t } = useTranslation(); - const isEnabled = useFeatureStatus('starterModels').isFeatureEnabled; + const isEnabled = useFeatureStatus('starterModels'); const [didToast, setDidToast] = useState(false); const [mainModels, { data }] = useMainModels(); const toast = useToast(); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx index 0a18c2a959..77be60a945 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx @@ -1,8 +1,9 @@ +import { Flex, Text } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; -import { BaseEdge, getBezierPath } from 'reactflow'; +import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; import { makeEdgeSelector } from './util/makeEdgeSelector'; @@ -25,9 +26,10 @@ const InvocationDefaultEdge = ({ [source, sourceHandleId, target, targetHandleId, selected] ); - const { isSelected, shouldAnimate, stroke } = useAppSelector(selector); + const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector); + const shouldShowEdgeLabels = useAppSelector((s) => 
s.nodes.shouldShowEdgeLabels); - const [edgePath] = getBezierPath({ + const [edgePath, labelX, labelY] = getBezierPath({ sourceX, sourceY, sourcePosition, @@ -47,7 +49,33 @@ const InvocationDefaultEdge = ({ [isSelected, shouldAnimate, stroke] ); - return ; + return ( + <> + + {label && shouldShowEdgeLabels && ( + + + + {label} + + + + )} + + ); }; export default memo(InvocationDefaultEdge); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index ba40b4984c..a485bf64c1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,7 +1,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; @@ -10,6 +10,7 @@ const defaultReturnValue = { isSelected: false, shouldAnimate: false, stroke: colorTokenToCssVar('base.500'), + label: '', }; export const makeEdgeSelector = ( @@ -19,25 +20,34 @@ export const makeEdgeSelector = ( targetHandleId: string | null | undefined, selected?: boolean ) => - createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { - const sourceNode = nodes.nodes.find((node) => node.id === source); - const targetNode = nodes.nodes.find((node) => node.id === target); + createMemoizedSelector( + selectNodesSlice, + (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => { + const sourceNode = nodes.nodes.find((node) => node.id === source); + const targetNode = nodes.nodes.find((node) => node.id === target); - const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); + const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); - if (!sourceNode || !sourceHandleId) { - return defaultReturnValue; + const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { + return defaultReturnValue; + } + + const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); + const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; + + const stroke = sourceType && nodes.shouldColorEdges ? 
getFieldColor(sourceType) : colorTokenToCssVar('base.500'); + + const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id); + const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id); + + const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + + return { + isSelected, + shouldAnimate: nodes.shouldAnimateEdges && isSelected, + stroke, + label, + }; } - - const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); - const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; - - const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); - - return { - isSelected, - shouldAnimate: nodes.shouldAnimateEdges && isSelected, - stroke, - }; - }); + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx index 1b93c0fdd3..c1ff625d25 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeFooter.tsx @@ -16,7 +16,7 @@ const props: ChakraProps = { w: 'unset' }; const InvocationNodeFooter = ({ nodeId }: Props) => { const hasImageOutput = useHasImageOutput(nodeId); - const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled; + const isCacheEnabled = useFeatureStatus('invocationCache'); return ( { - const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionMode } = nodes; + const { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + shouldShowEdgeLabels, + selectionMode, + } = nodes; return { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, + shouldShowEdgeLabels, selectionModeIsChecked: selectionMode === SelectionMode.Full, }; }); @@ -52,8 +61,14 @@ type Props = { const WorkflowEditorSettings = ({ children }: Props) => { const { isOpen, onOpen, onClose } = useDisclosure(); const dispatch = useAppDispatch(); - const { shouldAnimateEdges, shouldValidateGraph, shouldSnapToGrid, shouldColorEdges, selectionModeIsChecked } = - useAppSelector(selector); + const { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + shouldShowEdgeLabels, + selectionModeIsChecked, + } = useAppSelector(selector); const handleChangeShouldValidate = useCallback( (e: ChangeEvent) => { @@ -90,6 +105,13 @@ const WorkflowEditorSettings = ({ children }: Props) => { [dispatch] ); + const handleChangeShouldShowEdgeLabels = useCallback( + (e: ChangeEvent) => { + dispatch(shouldShowEdgeLabelsChanged(e.target.checked)); + }, + [dispatch] + ); + const { t } = useTranslation(); return ( @@ -137,6 +159,14 @@ const WorkflowEditorSettings = ({ children }: Props) => { {t('nodes.fullyContainNodesHelp')} + + + {t('nodes.showEdgeLabels')} + + + {t('nodes.showEdgeLabelsHelp')} + + {t('common.advanced')} diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e16b329c22..54c092370b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -1,5 +1,5 @@ -import { createSelector } from 
'@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodeTemplate } from 'features/nodes/store/selectors'; @@ -10,7 +10,7 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, (nodes) => { + createMemoizedSelector(selectNodesSlice, (nodes) => { const template = selectNodeTemplate(nodes, nodeId); if (!template) { return EMPTY_ARRAY; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts index b3676f9722..6e00d374f6 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWithFooter.ts @@ -5,8 +5,7 @@ import { useHasImageOutput } from './useHasImageOutput'; export const useWithFooter = (nodeId: string) => { const hasImageOutput = useHasImageOutput(nodeId); - const isCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled; - + const isCacheEnabled = useFeatureStatus('invocationCache'); const withFooter = useMemo(() => hasImageOutput || isCacheEnabled, [hasImageOutput, isCacheEnabled]); return withFooter; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 4a1b438271..0f0417cf71 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -103,6 +103,7 @@ const initialNodesState: NodesState = { shouldAnimateEdges: true, shouldSnapToGrid: false, shouldColorEdges: true, + shouldShowEdgeLabels: false, isAddNodePopoverOpen: false, nodeOpacity: 1, selectedNodes: [], @@ -549,6 +550,9 @@ export const nodesSlice = createSlice({ shouldAnimateEdgesChanged: (state, action: PayloadAction) => { state.shouldAnimateEdges = action.payload; }, + shouldShowEdgeLabelsChanged: (state, action: PayloadAction) => { + state.shouldShowEdgeLabels = action.payload; + }, shouldSnapToGridChanged: (state, action: PayloadAction) => { state.shouldSnapToGrid = action.payload; }, @@ -831,6 +835,7 @@ export const { viewportChanged, edgeAdded, nodeTemplatesBuilt, + shouldShowEdgeLabelsChanged, } = nodesSlice.actions; // This is used for tracking `state.workflow.isTouched` diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 3f7d210dcb..2074f1f342 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -32,6 +32,7 @@ export type NodesState = { isAddNodePopoverOpen: boolean; addNewNodePosition: XYPosition | null; selectionMode: SelectionMode; + shouldShowEdgeLabels: boolean; }; export type WorkflowMode = 'edit' | 'view'; diff --git a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx index 1cafe4310e..be173ca6ca 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx +++ 
b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillColorOptions.tsx @@ -1,24 +1,18 @@ import { Box, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import { createSelector } from '@reduxjs/toolkit'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIColorPicker from 'common/components/IAIColorPicker'; import { selectGenerationSlice, setInfillColorValue } from 'features/parameters/store/generationSlice'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import type { RgbaColor } from 'react-colorful'; import { useTranslation } from 'react-i18next'; +const selectInfillColor = createMemoizedSelector(selectGenerationSlice, (generation) => generation.infillColorValue); + const ParamInfillColorOptions = () => { const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector(selectGenerationSlice, (generation) => ({ - infillColor: generation.infillColorValue, - })), - [] - ); - - const { infillColor } = useAppSelector(selector); + const infillColor = useAppSelector(selectInfillColor); const infillMethod = useAppSelector((s) => s.generation.infillMethod); diff --git a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx index f164bb903e..cfdd7fb010 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Canvas/InfillAndScaling/ParamInfillMosaicOptions.tsx @@ -1,35 +1,23 @@ import { Box, CompositeNumberInput, CompositeSlider, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library'; -import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIColorPicker from 'common/components/IAIColorPicker'; import { - selectGenerationSlice, setInfillMosaicMaxColor, setInfillMosaicMinColor, setInfillMosaicTileHeight, setInfillMosaicTileWidth, } from 'features/parameters/store/generationSlice'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; import type { RgbaColor } from 'react-colorful'; import { useTranslation } from 'react-i18next'; const ParamInfillMosaicTileSize = () => { const dispatch = useAppDispatch(); - const selector = useMemo( - () => - createSelector(selectGenerationSlice, (generation) => ({ - infillMosaicTileWidth: generation.infillMosaicTileWidth, - infillMosaicTileHeight: generation.infillMosaicTileHeight, - infillMosaicMinColor: generation.infillMosaicMinColor, - infillMosaicMaxColor: generation.infillMosaicMaxColor, - })), - [] - ); - - const { infillMosaicTileWidth, infillMosaicTileHeight, infillMosaicMinColor, infillMosaicMaxColor } = - useAppSelector(selector); - + const infillMosaicTileWidth = useAppSelector((s) => s.generation.infillMosaicTileWidth); + const infillMosaicTileHeight = useAppSelector((s) => s.generation.infillMosaicTileHeight); + const infillMosaicMinColor = useAppSelector((s) => s.generation.infillMosaicMinColor); + const infillMosaicMaxColor = useAppSelector((s) => s.generation.infillMosaicMaxColor); const infillMethod = useAppSelector((s) => s.generation.infillMethod); const { t } = useTranslation(); diff --git 
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx b/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx
index 37b6318a53..101c82376c 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueActionsMenuButton.tsx
@@ -27,8 +27,8 @@ export const QueueActionsMenuButton = memo(() => {
   const dispatch = useAppDispatch();
   const { t } = useTranslation();
   const clearQueueDisclosure = useDisclosure();
-  const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
-  const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
+  const isPauseEnabled = useFeatureStatus('pauseQueue');
+  const isResumeEnabled = useFeatureStatus('resumeQueue');
   const { queueSize } = useGetQueueStatusQuery(undefined, {
     selectFromResult: (res) => ({
       queueSize: res.data ? res.data.queue.pending + res.data.queue.in_progress : 0,
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx b/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx
index 5499ecccfc..28a12808ea 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueControls.tsx
@@ -9,7 +9,7 @@ import { InvokeQueueBackButton } from './InvokeQueueBackButton';
 import { QueueActionsMenuButton } from './QueueActionsMenuButton';
 
 const QueueControls = () => {
-  const isPrependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
+  const isPrependEnabled = useFeatureStatus('prependQueue');
   return (
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx b/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx
index d7c27d54fc..2dae5e6ebe 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueTabContent.tsx
@@ -8,7 +8,7 @@ import QueueStatus from './QueueStatus';
 import QueueTabQueueControls from './QueueTabQueueControls';
 
 const QueueTabContent = () => {
-  const isInvocationCacheEnabled = useFeatureStatus('invocationCache').isFeatureEnabled;
+  const isInvocationCacheEnabled = useFeatureStatus('invocationCache');
 
   return (
diff --git a/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx b/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx
index b42fc04e4d..3aed2f237a 100644
--- a/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx
+++ b/invokeai/frontend/web/src/features/queue/components/QueueTabQueueControls.tsx
@@ -8,8 +8,8 @@ import PruneQueueButton from './PruneQueueButton';
 import ResumeProcessorButton from './ResumeProcessorButton';
 
 const QueueTabQueueControls = () => {
-  const isPauseEnabled = useFeatureStatus('pauseQueue').isFeatureEnabled;
-  const isResumeEnabled = useFeatureStatus('resumeQueue').isFeatureEnabled;
+  const isPauseEnabled = useFeatureStatus('pauseQueue');
+  const isResumeEnabled = useFeatureStatus('resumeQueue');
   return (
     {isPauseEnabled || isResumeEnabled ? (
diff --git a/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts b/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts
index d39ac96566..f6c71dbc5a 100644
--- a/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts
+++ b/invokeai/frontend/web/src/features/queue/hooks/useQueueFront.ts
@@ -13,7 +13,7 @@ export const useQueueFront = () => {
   const [_, { isLoading }] = useEnqueueBatchMutation({
     fixedCacheKey: 'enqueueBatch',
   });
-  const prependEnabled = useFeatureStatus('prependQueue').isFeatureEnabled;
+  const prependEnabled = useFeatureStatus('prependQueue');
 
   const isDisabled = useMemo(() => {
     return !isReady || !prependEnabled;
diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx
index 3cddc927e9..ec81b0b211 100644
--- a/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx
+++ b/invokeai/frontend/web/src/features/settingsAccordions/components/ControlSettingsAccordion/ControlSettingsAccordion.tsx
@@ -62,7 +62,7 @@ const selector = createMemoizedSelector(selectControlAdaptersSlice, (controlAdap
 export const ControlSettingsAccordion: React.FC = memo(() => {
   const { t } = useTranslation();
   const { controlAdapterIds, badges } = useAppSelector(selector);
-  const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
+  const isControlNetEnabled = useFeatureStatus('controlNet');
   const { isOpen, onToggle } = useStandaloneAccordionToggle({
     id: 'control-settings',
     defaultIsOpen: true,
@@ -71,7 +71,7 @@ export const ControlSettingsAccordion: React.FC = memo(() => {
   const [addIPAdapter, isAddIPAdapterDisabled] = useAddControlAdapter('ip_adapter');
   const [addT2IAdapter, isAddT2IAdapterDisabled] = useAddControlAdapter('t2i_adapter');
 
-  if (isControlNetDisabled) {
+  if (!isControlNetEnabled) {
     return null;
   }
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx
index eceba85b5a..bee5940530 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsLanguageSelect.tsx
@@ -40,7 +40,7 @@ export const SettingsLanguageSelect = memo(() => {
   const { t } = useTranslation();
   const dispatch = useAppDispatch();
   const language = useAppSelector((s) => s.system.language);
-  const isLocalizationEnabled = useFeatureStatus('localization').isFeatureEnabled;
+  const isLocalizationEnabled = useFeatureStatus('localization');
 
   const value = useMemo(() => options.find((o) => o.value === language), [language]);
diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx
index 6cf37334e7..b424a129ee 100644
--- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx
+++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsMenu.tsx
@@ -23,9 +23,9 @@ const SettingsMenu = () => {
   const { isOpen, onOpen, onClose } = useDisclosure();
   useGlobalMenuClose(onClose);
 
-  const isBugLinkEnabled = useFeatureStatus('bugLink').isFeatureEnabled;
-  const isDiscordLinkEnabled = useFeatureStatus('discordLink').isFeatureEnabled;
-  const isGithubLinkEnabled = useFeatureStatus('githubLink').isFeatureEnabled;
+  const isBugLinkEnabled = useFeatureStatus('bugLink');
+  const isDiscordLinkEnabled = useFeatureStatus('discordLink');
+  const isGithubLinkEnabled = useFeatureStatus('githubLink');
 
   return (
diff --git a/invokeai/frontend/web/src/features/system/hooks/useFeatureStatus.ts b/invokeai/frontend/web/src/features/system/hooks/useFeatureStatus.ts
index 2c13c06b36..527405cb7d 100644
--- a/invokeai/frontend/web/src/features/system/hooks/useFeatureStatus.ts
+++ b/invokeai/frontend/web/src/features/system/hooks/useFeatureStatus.ts
@@ -1,32 +1,24 @@
+import { createSelector } from '@reduxjs/toolkit';
 import { useAppSelector } from 'app/store/storeHooks';
 import type { AppFeature, SDFeature } from 'app/types/invokeai';
+import { selectConfigSlice } from 'features/system/store/configSlice';
 import type { InvokeTabName } from 'features/ui/store/tabMap';
 import { useMemo } from 'react';
 
 export const useFeatureStatus = (feature: AppFeature | SDFeature | InvokeTabName) => {
-  const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
-
-  const disabledFeatures = useAppSelector((s) => s.config.disabledFeatures);
-
-  const disabledSDFeatures = useAppSelector((s) => s.config.disabledSDFeatures);
-
-  const isFeatureDisabled = useMemo(
+  const selectIsFeatureEnabled = useMemo(
     () =>
-      disabledFeatures.includes(feature as AppFeature) ||
-      disabledSDFeatures.includes(feature as SDFeature) ||
-      disabledTabs.includes(feature as InvokeTabName),
-    [disabledFeatures, disabledSDFeatures, disabledTabs, feature]
+      createSelector(selectConfigSlice, (config) => {
+        return !(
+          config.disabledFeatures.includes(feature as AppFeature) ||
+          config.disabledSDFeatures.includes(feature as SDFeature) ||
+          config.disabledTabs.includes(feature as InvokeTabName)
+        );
+      }),
+    [feature]
   );
 
-  const isFeatureEnabled = useMemo(
-    () =>
-      !(
-        disabledFeatures.includes(feature as AppFeature) ||
-        disabledSDFeatures.includes(feature as SDFeature) ||
-        disabledTabs.includes(feature as InvokeTabName)
-      ),
-    [disabledFeatures, disabledSDFeatures, disabledTabs, feature]
-  );
+  const isFeatureEnabled = useAppSelector(selectIsFeatureEnabled);
 
-  return { isFeatureDisabled, isFeatureEnabled };
+  return isFeatureEnabled;
 };
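
With the hook above, `useFeatureStatus` now returns the enabled flag directly instead of an `{ isFeatureEnabled, isFeatureDisabled }` pair, and it derives the flag inside a per-feature memoized selector so components only re-render when that answer changes. A short sketch of a migrated consumer follows; the gate component is hypothetical and only illustrates the call pattern used in the diffs above:

import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import type { PropsWithChildren } from 'react';

// Hypothetical gate: renders children only when the 'localization' feature is enabled.
const LocalizationGate = (props: PropsWithChildren) => {
  const isLocalizationEnabled = useFeatureStatus('localization'); // now a plain boolean
  // The removed `.isFeatureDisabled` is simply the negation of the returned value.
  if (!isLocalizationEnabled) {
    return null;
  }
  return <>{props.children}</>;
};

export default LocalizationGate;
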
+""" + +from unittest.mock import patch + +import pytest +import torch + +from invokeai.app.services.config import get_config +from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype + +devices = ["cpu", "cuda:0", "cuda:1", "mps"] +device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)] +device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)] +device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)] + + +@pytest.mark.parametrize("device_name", devices) +def test_device_choice(device_name): + config = get_config() + config.device = device_name + torch_device = TorchDevice.choose_torch_device() + assert torch_device == torch.device(device_name) + + +@pytest.mark.parametrize("device_dtype_pair", device_types_cpu) +def test_device_dtype_cpu(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=False), + patch("torch.backends.mps.is_available", return_value=False), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == dtype + + +@pytest.mark.parametrize("device_dtype_pair", device_types_cuda) +def test_device_dtype_cuda(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=True), + patch("torch.cuda.get_device_name", return_value="RTX4070"), + patch("torch.backends.mps.is_available", return_value=False), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == dtype + + +@pytest.mark.parametrize("device_dtype_pair", device_types_mps) +def test_device_dtype_mps(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=False), + patch("torch.backends.mps.is_available", return_value=True), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == dtype + + +@pytest.mark.parametrize("device_dtype_pair", device_types_cuda) +def test_device_dtype_override(device_dtype_pair): + with ( + patch("torch.cuda.get_device_name", return_value="RTX4070"), + patch("torch.cuda.is_available", return_value=True), + patch("torch.backends.mps.is_available", return_value=False), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + config.precision = "float32" + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == torch.float32 + + +def test_normalize(): + assert ( + TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda") + ) + assert ( + TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda") + ) + assert ( + TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda") + ) + assert TorchDevice.normalize("mps") == torch.device("mps") + assert TorchDevice.normalize("cpu") == torch.device("cpu") + + +@pytest.mark.parametrize("device_name", devices) +def test_legacy_device_choice(device_name): + config = get_config() + config.device = device_name + with pytest.deprecated_call(): + torch_device = choose_torch_device() + assert torch_device == torch.device(device_name) + + +@pytest.mark.parametrize("device_dtype_pair", 
+def test_legacy_device_dtype_cpu(device_dtype_pair):
+    with (
+        patch("torch.cuda.is_available", return_value=False),
+        patch("torch.backends.mps.is_available", return_value=False),
+        patch("torch.cuda.get_device_name", return_value="RTX9090"),
+    ):
+        device_name, dtype = device_dtype_pair
+        config = get_config()
+        config.device = device_name
+        with pytest.deprecated_call():
+            torch_device = choose_torch_device()
+            returned_dtype = torch_dtype(torch_device)
+        assert returned_dtype == dtype
+
+
+def test_legacy_precision_name():
+    config = get_config()
+    config.precision = "auto"
+    with (
+        pytest.deprecated_call(),
+        patch("torch.cuda.is_available", return_value=True),
+        patch("torch.backends.mps.is_available", return_value=True),
+        patch("torch.cuda.get_device_name", return_value="RTX9090"),
+    ):
+        assert "float16" == choose_precision(torch.device("cuda"))
+        assert "float16" == choose_precision(torch.device("mps"))
+        assert "float32" == choose_precision(torch.device("cpu"))