From e93f4d632dd1aed95bbaed79d2797a9360029c66 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 15 Apr 2024 09:12:49 -0400 Subject: [PATCH] [util] Add generic torch device class (#6174) * introduce new abstraction layer for GPU devices * add unit test for device abstraction * fix ruff * convert TorchDeviceSelect into a stateless class * move logic to select context-specific execution device into context API * add mock hardware environments to pytest * remove dangling mocker fixture * fix unit test for running on non-CUDA systems * remove unimplemented get_execution_device() call * remove autocast precision * Multiple changes: 1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to context.models.get_execution_device(). 2. Rename TorchDeviceSelect to TorchDevice 3. Added back the legacy public API defined in `invocation_api`, including choose_precision(). 4. Added a config file migration script to accommodate removal of precision=autocast. * add deprecation warnings to choose_torch_device() and choose_precision() * fix test crash * remove app_config argument from choose_torch_device() and choose_torch_dtype() --------- Co-authored-by: Lincoln Stein --- invokeai/app/api_app.py | 4 +- invokeai/app/invocations/compel.py | 6 +- invokeai/app/invocations/latent.py | 43 ++--- invokeai/app/invocations/noise.py | 8 +- invokeai/app/invocations/upscale.py | 10 +- .../app/services/config/config_default.py | 57 ++++-- .../model_install/model_install_default.py | 10 +- .../model_manager/model_manager_default.py | 8 +- .../image_util/depth_anything/__init__.py | 8 +- .../image_util/dw_openpose/wholebody.py | 6 +- .../backend/image_util/infill_methods/lama.py | 4 +- .../image_util/realesrgan/realesrgan.py | 4 +- invokeai/backend/image_util/safety_checker.py | 4 +- .../model_manager/load/load_default.py | 4 +- .../load/model_cache/model_cache_default.py | 14 +- invokeai/backend/model_manager/merge.py | 7 +- .../stable_diffusion/diffusers_pipeline.py | 4 +- invokeai/backend/util/__init__.py | 3 - invokeai/backend/util/devices.py | 167 ++++++++++-------- tests/backend/util/test_devices.py | 132 ++++++++++++++ 20 files changed, 327 insertions(+), 176 deletions(-) create mode 100644 tests/backend/util/test_devices.py diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 87eaefc020..ceaeb95147 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -28,7 +28,7 @@ from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.config.config_default import get_config from invokeai.app.services.session_processor.session_processor_common import ProgressImage -from invokeai.backend.util.devices import get_torch_device_name +from invokeai.backend.util.devices import TorchDevice from ..backend.util.logging import InvokeAILogger from .api.dependencies import ApiDependencies @@ -63,7 +63,7 @@ logger = InvokeAILogger.get_logger(config=app_config) mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("text/css", ".css") -torch_device_name = get_torch_device_name() +torch_device_name = TorchDevice.get_torch_device_name() logger.info(f"Using torch device: {torch_device_name}") diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 481a2d2e4b..158f11a58e 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -24,7 +24,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data 
import ( ConditioningFieldData, SDXLConditioningInfo, ) -from invokeai.backend.util.devices import torch_dtype +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .model import CLIPField @@ -99,7 +99,7 @@ class CompelInvocation(BaseInvocation): tokenizer=tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, + dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, ) @@ -193,7 +193,7 @@ class SDXLPromptInvocationBase: tokenizer=tokenizer, text_encoder=text_encoder, textual_inversion_manager=ti_manager, - dtype_for_device_getter=torch_dtype, + dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=get_pooled, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index ce63d568c6..a8ead96f3a 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -72,15 +72,12 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( image_resized_to_grid_as_tensor, ) from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.util.devices import choose_precision, choose_torch_device +from ...backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output from .controlnet_image_processors import ControlField from .model import ModelIdentifierField, UNetField, VAEField -if choose_torch_device() == torch.device("mps"): - from torch import mps - -DEFAULT_PRECISION = choose_precision(choose_torch_device()) +DEFAULT_PRECISION = TorchDevice.choose_torch_dtype() @invocation_output("scheduler_output") @@ -959,9 +956,7 @@ class DenoiseLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 result_latents = result_latents.to("cpu") - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() name = context.tensors.save(tensor=result_latents) return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) @@ -1028,9 +1023,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): vae.disable_tiling() # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() with torch.inference_mode(): # copied from diffusers pipeline @@ -1042,9 +1035,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard): image = VaeImageProcessor.numpy_to_pil(np_image)[0] - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() image_dto = context.images.save(image=image) @@ -1083,9 +1074,7 @@ class ResizeLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.tensors.load(self.latents.latents_name) - - # TODO: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() resized_latents = torch.nn.functional.interpolate( latents.to(device), @@ -1096,9 +1085,8 @@ class ResizeLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 resized_latents = 
resized_latents.to("cpu") - torch.cuda.empty_cache() - if device == torch.device("mps"): - mps.empty_cache() + + TorchDevice.empty_cache() name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1125,8 +1113,7 @@ class ScaleLatentsInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.tensors.load(self.latents.latents_name) - # TODO: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() # resizing resized_latents = torch.nn.functional.interpolate( @@ -1138,9 +1125,7 @@ class ScaleLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 resized_latents = resized_latents.to("cpu") - torch.cuda.empty_cache() - if device == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() name = context.tensors.save(tensor=resized_latents) return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed) @@ -1272,8 +1257,7 @@ class BlendLatentsInvocation(BaseInvocation): if latents_a.shape != latents_b.shape: raise Exception("Latents to blend must be the same size.") - # TODO: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() def slerp( t: Union[float, npt.NDArray[Any]], # FIXME: maybe use np.float32 here? @@ -1326,9 +1310,8 @@ class BlendLatentsInvocation(BaseInvocation): # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699 blended_latents = blended_latents.to("cpu") - torch.cuda.empty_cache() - if device == torch.device("mps"): - mps.empty_cache() + + TorchDevice.empty_cache() name = context.tensors.save(tensor=blended_latents) return LatentsOutput.build(latents_name=name, latents=blended_latents) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index 6e5612b8b0..931e639106 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -9,7 +9,7 @@ from invokeai.app.invocations.fields import FieldDescriptions, InputField, Laten from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX -from ...backend.util.devices import choose_torch_device, torch_dtype +from ...backend.util.devices import TorchDevice from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -46,7 +46,7 @@ def get_noise( height // downsampling_factor, width // downsampling_factor, ], - dtype=torch_dtype(device), + dtype=TorchDevice.choose_torch_dtype(device=device), device=noise_device_type, generator=generator, ).to("cpu") @@ -111,14 +111,14 @@ class NoiseInvocation(BaseInvocation): @field_validator("seed", mode="before") def modulo_seed(cls, v): - """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" + """Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" return v % (SEED_MAX + 1) def invoke(self, context: InvocationContext) -> NoiseOutput: noise = get_noise( width=self.width, height=self.height, - device=choose_torch_device(), + device=TorchDevice.choose_torch_device(), seed=self.seed, use_cpu=self.use_cpu, ) diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index d687384fcb..deaf5696c6 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,7 +4,6 @@ from typing import Literal import cv2 import numpy as np -import torch from PIL import 
Image from pydantic import ConfigDict @@ -14,7 +13,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from .baseinvocation import BaseInvocation, invocation from .fields import InputField, WithBoard, WithMetadata @@ -35,9 +34,6 @@ ESRGAN_MODEL_URLS: dict[str, str] = { "RealESRGAN_x2plus.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", } -if choose_torch_device() == torch.device("mps"): - from torch import mps - @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2") class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -120,9 +116,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): upscaled_image = upscaler.upscale(cv2_image) pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() image_dto = context.images.save(image=pil_image) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index f453a56584..54a092d03e 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -27,12 +27,12 @@ DEFAULT_RAM_CACHE = 10.0 DEFAULT_VRAM_CACHE = 0.25 DEFAULT_CONVERT_CACHE = 20.0 DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"] -PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"] +PRECISION = Literal["auto", "float16", "bfloat16", "float32"] ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"] ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"] LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"] -CONFIG_SCHEMA_VERSION = "4.0.0" +CONFIG_SCHEMA_VERSION = "4.0.1" def get_default_ram_cache_size() -> float: @@ -105,7 +105,7 @@ class InvokeAIAppConfig(BaseSettings): lazy_offload: Keep models in VRAM until their space is needed. log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
- precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
+ precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type. Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
attention_slice_size: Slice size, valid when attention_type=="sliced".
Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8` @@ -370,6 +370,9 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used if k == "max_vram_cache_size" and "vram" not in category_dict: parsed_config_dict["vram"] = v + # autocast was removed in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" if k == "conf_path": parsed_config_dict["legacy_models_yaml_path"] = v if k == "legacy_conf_dir": @@ -392,6 +395,28 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: return config +def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig: + """Migrate v4.0.0 config dictionary to a current config object. + + Args: + config_dict: A dictionary of settings from a v4.0.0 config file. + + Returns: + An instance of `InvokeAIAppConfig` with the migrated settings. + """ + parsed_config_dict: dict[str, Any] = {} + for k, v in config_dict.items(): + # autocast was removed from precision in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + else: + parsed_config_dict[k] = v + if k == "schema_version": + parsed_config_dict[k] = CONFIG_SCHEMA_VERSION + config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict) + return config + + def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: """Load and migrate a config file to the latest version. @@ -418,17 +443,21 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e migrated_config.write_file(config_path) return migrated_config - else: - # Attempt to load as a v4 config file - try: - # Meta is not included in the model fields, so we need to validate it separately - config = InvokeAIAppConfig.model_validate(loaded_config_dict) - assert ( - config.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" - return config - except Exception as e: - raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e + + if loaded_config_dict["schema_version"] == "4.0.0": + loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict) + loaded_config_dict.write_file(config_path) + + # Attempt to load as a v4 config file + try: + # Meta is not included in the model fields, so we need to validate it separately + config = InvokeAIAppConfig.model_validate(loaded_config_dict) + assert ( + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" + return config + except Exception as e: + raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e @lru_cache(maxsize=1) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 20cfc1c4ff..5aa0f199fc 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -13,6 +13,7 @@ from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp from typing import Any, Dict, List, Optional, Union +import torch import yaml from huggingface_hub import HfFolder from pydantic.networks import AnyHttpUrl @@ -42,7 +43,7 @@ from 
invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMet from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.util import InvokeAILogger -from invokeai.backend.util.devices import choose_precision, choose_torch_device +from invokeai.backend.util.devices import TorchDevice from .model_install_base import ( MODEL_SOURCE_TO_TYPE_MAP, @@ -634,11 +635,10 @@ class ModelInstallService(ModelInstallServiceBase): self._next_job_id += 1 return id - @staticmethod - def _guess_variant() -> Optional[ModelRepoVariant]: + def _guess_variant(self) -> Optional[ModelRepoVariant]: """Guess the best HuggingFace variant type to download.""" - precision = choose_precision(choose_torch_device()) - return ModelRepoVariant.FP16 if precision == "float16" else None + precision = TorchDevice.choose_torch_dtype() + return ModelRepoVariant.FP16 if precision == torch.float16 else None def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob: return ModelInstallJob( diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index de6e5f09d8..1a2b9a3402 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -1,12 +1,14 @@ # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team """Implementation of ModelManagerServiceBase.""" +from typing import Optional + import torch from typing_extensions import Self from invokeai.app.services.invoker import Invoker from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from ..config import InvokeAIAppConfig @@ -67,7 +69,7 @@ class ModelManagerService(ModelManagerServiceBase): model_record_service: ModelRecordServiceBase, download_queue: DownloadQueueServiceBase, events: EventServiceBase, - execution_device: torch.device = choose_torch_device(), + execution_device: Optional[torch.device] = None, ) -> Self: """ Construct the model manager service instance. 
@@ -82,7 +84,7 @@ class ModelManagerService(ModelManagerServiceBase): max_vram_cache_size=app_config.vram, lazy_offloading=app_config.lazy_offload, logger=logger, - execution_device=execution_device, + execution_device=execution_device or TorchDevice.choose_torch_device(), ) convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache) loader = ModelLoadService( diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index ccac2ba949..c854fba3f2 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -13,7 +13,7 @@ from invokeai.app.services.config.config_default import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger config = get_config() @@ -56,7 +56,7 @@ class DepthAnythingDetector: def __init__(self) -> None: self.model = None self.model_size: Union[Literal["large", "base", "small"], None] = None - self.device = choose_torch_device() + self.device = TorchDevice.choose_torch_device() def load_model(self, model_size: Literal["large", "base", "small"] = "small"): DEPTH_ANYTHING_MODEL_PATH = config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"] @@ -81,7 +81,7 @@ class DepthAnythingDetector: self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) self.model.eval() - self.model.to(choose_torch_device()) + self.model.to(self.device) return self.model def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image: @@ -94,7 +94,7 @@ class DepthAnythingDetector: image_height, image_width = np_image.shape[:2] np_image = transform({"image": np_image})["image"] - tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device()) + tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(self.device) with torch.no_grad(): depth = self.model(tensor_image) diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 35d340640d..84f5afa989 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -7,7 +7,7 @@ import onnxruntime as ort from invokeai.app.services.config.config_default import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from .onnxdet import inference_detector from .onnxpose import inference_pose @@ -28,9 +28,9 @@ config = get_config() class Wholebody: def __init__(self): - device = choose_torch_device() + device = TorchDevice.choose_torch_device() - providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"] + providers = ["CUDAExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] DET_MODEL_PATH = config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"] download_with_progress_bar("yolox_l.onnx", DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH) diff --git 
a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index fa354aeed1..4268ec773d 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -8,7 +8,7 @@ from PIL import Image import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice def norm_img(np_img): @@ -29,7 +29,7 @@ def load_jit_model(url_or_path, device): class LaMA: def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: - device = choose_torch_device() + device = TorchDevice.choose_torch_device() model_location = get_config().models_path / "core/misc/lama/lama.pt" if not model_location.exists(): diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py index c06504b608..663a323967 100644 --- a/invokeai/backend/image_util/realesrgan/realesrgan.py +++ b/invokeai/backend/image_util/realesrgan/realesrgan.py @@ -11,7 +11,7 @@ from cv2.typing import MatLike from tqdm import tqdm from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice """ Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py @@ -65,7 +65,7 @@ class RealESRGAN: self.pre_pad = pre_pad self.mod_scale: Optional[int] = None self.half = half - self.device = choose_torch_device() + self.device = TorchDevice.choose_torch_device() loadnet = torch.load(model_path, map_location=torch.device("cpu")) diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 7bceae8da7..60dcd93fcc 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -13,7 +13,7 @@ from transformers import AutoFeatureExtractor import invokeai.backend.util.logging as logger from invokeai.app.services.config.config_default import get_config -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.silence_warnings import SilenceWarnings CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" @@ -51,7 +51,7 @@ class SafetyChecker: cls._load_safety_checker() if cls.safety_checker is None or cls.feature_extractor is None: return False - device = choose_torch_device() + device = TorchDevice.choose_torch_device() features = cls.feature_extractor([image], return_tensors="pt") features.to(device) cls.safety_checker.to(device) diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 6774fc2989..a58741763f 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -18,7 +18,7 @@ from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoad from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init -from 
invokeai.backend.util.devices import choose_torch_device, torch_dtype +from invokeai.backend.util.devices import TorchDevice # TO DO: The loader is not thread safe! @@ -37,7 +37,7 @@ class ModelLoader(ModelLoaderBase): self._logger = logger self._ram_cache = ram_cache self._convert_cache = convert_cache - self._torch_dtype = torch_dtype(choose_torch_device()) + self._torch_dtype = TorchDevice.choose_torch_dtype() def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 2ba52d466c..2ffe954e11 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -30,15 +30,12 @@ import torch from invokeai.backend.model_manager import AnyModel, SubModelType from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff -from invokeai.backend.util.devices import choose_torch_device +from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase from .model_locker import ModelLocker -if choose_torch_device() == torch.device("mps"): - from torch import mps - # Maximum size of the cache, in gigs # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously DEFAULT_MAX_CACHE_SIZE = 6.0 @@ -244,9 +241,7 @@ class ModelCache(ModelCacheBase[AnyModel]): f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" ) - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() + TorchDevice.empty_cache() def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: """Move model into the indicated device. @@ -416,10 +411,7 @@ class ModelCache(ModelCacheBase[AnyModel]): self.stats.cleared = models_cleared gc.collect() - torch.cuda.empty_cache() - if choose_torch_device() == torch.device("mps"): - mps.empty_cache() - + TorchDevice.empty_cache() self.logger.debug(f"After making room: cached_models={len(self._cached_models)}") def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None: diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index eb6fd45e1a..125e99be93 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -17,7 +17,7 @@ from diffusers.utils import logging as dlogging from invokeai.app.services.model_install import ModelInstallServiceBase from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.backend.util.devices import choose_torch_device, torch_dtype +from invokeai.backend.util.devices import TorchDevice from . import ( AnyModelConfig, @@ -43,6 +43,7 @@ class ModelMerger(object): Initialize a ModelMerger object with the model installer. 
""" self._installer = installer + self._dtype = TorchDevice.choose_torch_dtype() def merge_diffusion_models( self, @@ -68,7 +69,7 @@ class ModelMerger(object): warnings.simplefilter("ignore") verbosity = dlogging.get_verbosity() dlogging.set_verbosity_error() - dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device()) + dtype = torch.float16 if variant == "fp16" else self._dtype # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released. @@ -151,7 +152,7 @@ class ModelMerger(object): dump_path.mkdir(parents=True, exist_ok=True) dump_path = dump_path / merged_model_name - dtype = torch.float16 if variant == "fp16" else torch_dtype(choose_torch_device()) + dtype = torch.float16 if variant == "fp16" else self._dtype merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant) # register model and get its unique key diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index b4d1b3381c..bd60b0b8c7 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher from invokeai.backend.util.attention import auto_detect_slice_size -from invokeai.backend.util.devices import normalize_device +from invokeai.backend.util.devices import TorchDevice @dataclass @@ -258,7 +258,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): if self.unet.device.type == "cpu" or self.unet.device.type == "mps": mem_free = psutil.virtual_memory().free elif self.unet.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(normalize_device(self.unet.device)) + mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device)) else: raise ValueError(f"unrecognized device {self.unet.device}") # input tensor of [1, 4, h/8, w/8] diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py index 2c9cceff2c..1e4d467cd0 100644 --- a/invokeai/backend/util/__init__.py +++ b/invokeai/backend/util/__init__.py @@ -2,7 +2,6 @@ Initialization file for invokeai.backend.util """ -from .devices import choose_precision, choose_torch_device from .logging import InvokeAILogger from .util import GIG, Chdir, directory_size @@ -11,6 +10,4 @@ __all__ = [ "directory_size", "Chdir", "InvokeAILogger", - "choose_precision", - "choose_torch_device", ] diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index cb6b93eaac..e8380dc8bc 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,89 +1,110 @@ -from __future__ import annotations - -from contextlib import nullcontext -from typing import Literal, Optional, Union +from typing import Dict, Literal, Optional, Union import torch -from torch import autocast +from deprecated import deprecated -from invokeai.app.services.config.config_default import PRECISION, get_config +from invokeai.app.services.config.config_default import get_config +# legacy APIs +TorchPrecisionNames = Literal["float32", "float16", "bfloat16"] CPU_DEVICE = torch.device("cpu") CUDA_DEVICE = 
torch.device("cuda") MPS_DEVICE = torch.device("mps") +@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore +def choose_precision(device: torch.device) -> TorchPrecisionNames: + """Return the string representation of the recommended torch device.""" + torch_dtype = TorchDevice.choose_torch_dtype(device) + return PRECISION_TO_NAME[torch_dtype] + + +@deprecated("Use TorchDevice.choose_torch_device() instead.") # type: ignore def choose_torch_device() -> torch.device: - """Convenience routine for guessing which GPU device to run model on""" - config = get_config() - if config.device == "auto": - if torch.cuda.is_available(): - return torch.device("cuda") - if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - return torch.device("mps") + """Return the torch.device to use for accelerated inference.""" + return TorchDevice.choose_torch_device() + + +@deprecated("Use TorchDevice.choose_torch_dtype() instead.") # type: ignore +def torch_dtype(device: torch.device) -> torch.dtype: + """Return the torch precision for the recommended torch device.""" + return TorchDevice.choose_torch_dtype(device) + + +NAME_TO_PRECISION: Dict[TorchPrecisionNames, torch.dtype] = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, +} +PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NAME_TO_PRECISION.items()} + + +class TorchDevice: + """Abstraction layer for torch devices.""" + + @classmethod + def choose_torch_device(cls) -> torch.device: + """Return the torch.device to use for accelerated inference.""" + app_config = get_config() + if app_config.device != "auto": + device = torch.device(app_config.device) + elif torch.cuda.is_available(): + device = CUDA_DEVICE + elif torch.backends.mps.is_available(): + device = MPS_DEVICE else: - return CPU_DEVICE - else: - return torch.device(config.device) + device = CPU_DEVICE + return cls.normalize(device) + @classmethod + def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype: + """Return the precision to use for accelerated inference.""" + device = device or cls.choose_torch_device() + config = get_config() + if device.type == "cuda" and torch.cuda.is_available(): + device_name = torch.cuda.get_device_name(device) + if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name: + # These GPUs have limited support for float16 + return cls._to_dtype("float32") + elif config.precision == "auto": + # Default to float16 for CUDA devices + return cls._to_dtype("float16") + else: + # Use the user-defined precision + return cls._to_dtype(config.precision) -def get_torch_device_name() -> str: - device = choose_torch_device() - return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + elif device.type == "mps" and torch.backends.mps.is_available(): + if config.precision == "auto": + # Default to float16 for MPS devices + return cls._to_dtype("float16") + else: + # Use the user-defined precision + return cls._to_dtype(config.precision) + # CPU / safe fallback + return cls._to_dtype("float32") + @classmethod + def get_torch_device_name(cls) -> str: + """Return the device name for the current torch device.""" + device = cls.choose_torch_device() + return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() -def choose_precision(device: torch.device) -> Literal["float32", "float16", "bfloat16"]: - """Return an appropriate precision for the given torch device.""" - app_config 
= get_config() - if device.type == "cuda": - device_name = torch.cuda.get_device_name(device) - if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name: - # These GPUs have limited support for float16 - return "float32" - elif app_config.precision == "auto" or app_config.precision == "autocast": - # Default to float16 for CUDA devices - return "float16" - else: - # Use the user-defined precision - return app_config.precision - elif device.type == "mps": - if app_config.precision == "auto" or app_config.precision == "autocast": - # Default to float16 for MPS devices - return "float16" - else: - # Use the user-defined precision - return app_config.precision - # CPU / safe fallback - return "float32" - - -def torch_dtype(device: Optional[torch.device] = None) -> torch.dtype: - device = device or choose_torch_device() - precision = choose_precision(device) - if precision == "float16": - return torch.float16 - if precision == "bfloat16": - return torch.bfloat16 - else: - # "auto", "autocast", "float32" - return torch.float32 - - -def choose_autocast(precision: PRECISION): - """Returns an autocast context or nullcontext for the given precision string""" - # float16 currently requires autocast to avoid errors like: - # 'expected scalar type Half but found Float' - if precision == "autocast" or precision == "float16": - return autocast - return nullcontext - - -def normalize_device(device: Union[str, torch.device]) -> torch.device: - """Ensure device has a device index defined, if appropriate.""" - device = torch.device(device) - if device.index is None: - # cuda might be the only torch backend that currently uses the device index? - # I don't see anything like `current_device` for cpu or mps. - if device.type == "cuda": + @classmethod + def normalize(cls, device: Union[str, torch.device]) -> torch.device: + """Add the device index to CUDA devices.""" + device = torch.device(device) + if device.index is None and device.type == "cuda" and torch.cuda.is_available(): device = torch.device(device.type, torch.cuda.current_device()) - return device + return device + + @classmethod + def empty_cache(cls) -> None: + """Clear the GPU device cache.""" + if torch.backends.mps.is_available(): + torch.mps.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + @classmethod + def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype: + return NAME_TO_PRECISION[precision_name] diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py new file mode 100644 index 0000000000..8e810e4367 --- /dev/null +++ b/tests/backend/util/test_devices.py @@ -0,0 +1,132 @@ +""" +Test abstract device class. 
+""" + +from unittest.mock import patch + +import pytest +import torch + +from invokeai.app.services.config import get_config +from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype + +devices = ["cpu", "cuda:0", "cuda:1", "mps"] +device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)] +device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)] +device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)] + + +@pytest.mark.parametrize("device_name", devices) +def test_device_choice(device_name): + config = get_config() + config.device = device_name + torch_device = TorchDevice.choose_torch_device() + assert torch_device == torch.device(device_name) + + +@pytest.mark.parametrize("device_dtype_pair", device_types_cpu) +def test_device_dtype_cpu(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=False), + patch("torch.backends.mps.is_available", return_value=False), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == dtype + + +@pytest.mark.parametrize("device_dtype_pair", device_types_cuda) +def test_device_dtype_cuda(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=True), + patch("torch.cuda.get_device_name", return_value="RTX4070"), + patch("torch.backends.mps.is_available", return_value=False), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == dtype + + +@pytest.mark.parametrize("device_dtype_pair", device_types_mps) +def test_device_dtype_mps(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=False), + patch("torch.backends.mps.is_available", return_value=True), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == dtype + + +@pytest.mark.parametrize("device_dtype_pair", device_types_cuda) +def test_device_dtype_override(device_dtype_pair): + with ( + patch("torch.cuda.get_device_name", return_value="RTX4070"), + patch("torch.cuda.is_available", return_value=True), + patch("torch.backends.mps.is_available", return_value=False), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + config.precision = "float32" + torch_dtype = TorchDevice.choose_torch_dtype() + assert torch_dtype == torch.float32 + + +def test_normalize(): + assert ( + TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda") + ) + assert ( + TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda") + ) + assert ( + TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda") + ) + assert TorchDevice.normalize("mps") == torch.device("mps") + assert TorchDevice.normalize("cpu") == torch.device("cpu") + + +@pytest.mark.parametrize("device_name", devices) +def test_legacy_device_choice(device_name): + config = get_config() + config.device = device_name + with pytest.deprecated_call(): + torch_device = choose_torch_device() + assert torch_device == torch.device(device_name) + + +@pytest.mark.parametrize("device_dtype_pair", 
device_types_cpu) +def test_legacy_device_dtype_cpu(device_dtype_pair): + with ( + patch("torch.cuda.is_available", return_value=False), + patch("torch.backends.mps.is_available", return_value=False), + patch("torch.cuda.get_device_name", return_value="RTX9090"), + ): + device_name, dtype = device_dtype_pair + config = get_config() + config.device = device_name + with pytest.deprecated_call(): + torch_device = choose_torch_device() + returned_dtype = torch_dtype(torch_device) + assert returned_dtype == dtype + + +def test_legacy_precision_name(): + config = get_config() + config.precision = "auto" + with ( + pytest.deprecated_call(), + patch("torch.cuda.is_available", return_value=True), + patch("torch.backends.mps.is_available", return_value=True), + patch("torch.cuda.get_device_name", return_value="RTX9090"), + ): + assert "float16" == choose_precision(torch.device("cuda")) + assert "float16" == choose_precision(torch.device("mps")) + assert "float32" == choose_precision(torch.device("cpu"))
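
For reference, a minimal usage sketch of the TorchDevice API added by this patch, matching the call sites in the diff above. The tensor shape and the print call are illustrative only and do not appear in the diff.

import torch

from invokeai.backend.util.devices import TorchDevice

# Pick the execution device and the matching dtype once, up front.
# choose_torch_device() honours the `device:` config setting, otherwise CUDA > MPS > CPU.
device = TorchDevice.choose_torch_device()

# choose_torch_dtype() returns float16 on CUDA/MPS when precision is "auto"
# (with a float32 carve-out for GTX 1660/1650 cards) and float32 on CPU.
dtype = TorchDevice.choose_torch_dtype()

# Allocate working tensors on the chosen device at the chosen precision.
latents = torch.zeros((1, 4, 64, 64), device=device, dtype=dtype)

# ... run inference ...

# Release cached accelerator memory; covers both CUDA and MPS, a no-op on CPU.
TorchDevice.empty_cache()

# Human-readable device name: the CUDA adapter name, or "MPS" / "CPU".
print(TorchDevice.get_torch_device_name())

The legacy module-level helpers choose_torch_device(), choose_precision(), and torch_dtype() remain importable but now emit DeprecationWarning, as the new tests in tests/backend/util/test_devices.py assert.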