InvokeAI/tests/backend/util/test_devices.py
Lincoln Stein e93f4d632d
[util] Add generic torch device class (#6174)
* introduce new abstraction layer for GPU devices

* add unit test for device abstraction

* fix ruff

* convert TorchDeviceSelect into a stateless class

* move logic to select context-specific execution device into context API

* add mock hardware environments to pytest

* remove dangling mocker fixture

* fix unit test for running on non-CUDA systems

* remove unimplemented get_execution_device() call

* remove autocast precision

* Multiple changes:

1. Remove TorchDeviceSelect.get_execution_device(), as well as calls to
   context.models.get_execution_device().
2. Rename TorchDeviceSelect to TorchDevice
3. Added back the legacy public API defined in `invocation_api`, including
   choose_precision().
4. Added a config file migration script to accommodate removal of precision=autocast.

* add deprecation warnings to choose_torch_device() and choose_precision()

* fix test crash

* remove app_config argument from choose_torch_device() and choose_torch_dtype()

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
2024-04-15 13:12:49 +00:00

133 lines
4.8 KiB
Python

"""
Test abstract device class.
"""
from unittest.mock import patch
import pytest
import torch
from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
# Device names fed to the device-selection tests.
devices = [
    "cpu",
    "cuda:0",
    "cuda:1",
    "mps",
]

# (configured device name, expected torch dtype) pairs. Each table is used by
# the test that mocks the matching hardware environment.

# Used where both CUDA and MPS are patched as unavailable.
device_types_cpu = [
    ("cpu", torch.float32),
    ("cuda:0", torch.float32),
    ("mps", torch.float32),
]

# Used where CUDA is patched as available and MPS as unavailable.
device_types_cuda = [
    ("cpu", torch.float32),
    ("cuda:0", torch.float16),
    ("mps", torch.float32),
]

# Used where MPS is patched as available and CUDA as unavailable.
device_types_mps = [
    ("cpu", torch.float32),
    ("cuda:0", torch.float32),
    ("mps", torch.float16),
]
@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
    """The device named in the app config is what TorchDevice.choose_torch_device() returns."""
    # choose_torch_device() is called without arguments, so the requested
    # device is communicated through the config object from get_config().
    get_config().device = device_name
    assert TorchDevice.choose_torch_device() == torch.device(device_name)
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
    """With neither CUDA nor MPS reported as available, check the dtype chosen for each configured device."""
    requested_device, expected_dtype = device_dtype_pair

    # Simulate a machine with no accelerator at all.
    cuda_absent = patch("torch.cuda.is_available", return_value=False)
    mps_absent = patch("torch.backends.mps.is_available", return_value=False)

    with cuda_absent, mps_absent:
        get_config().device = requested_device
        assert TorchDevice.choose_torch_dtype() == expected_dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
    """With CUDA reported as available (and MPS not), check the dtype chosen for each configured device."""
    requested_device, expected_dtype = device_dtype_pair

    # Simulate a CUDA machine. The GPU name is stubbed as well; presumably the
    # dtype selection inspects it (confirm against TorchDevice).
    cuda_present = patch("torch.cuda.is_available", return_value=True)
    gpu_name = patch("torch.cuda.get_device_name", return_value="RTX4070")
    mps_absent = patch("torch.backends.mps.is_available", return_value=False)

    with cuda_present, gpu_name, mps_absent:
        get_config().device = requested_device
        assert TorchDevice.choose_torch_dtype() == expected_dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
    """With MPS reported as available (and CUDA not), check the dtype chosen for each configured device."""
    requested_device, expected_dtype = device_dtype_pair

    # Simulate an MPS-only machine.
    cuda_absent = patch("torch.cuda.is_available", return_value=False)
    mps_present = patch("torch.backends.mps.is_available", return_value=True)

    with cuda_absent, mps_present:
        get_config().device = requested_device
        assert TorchDevice.choose_torch_dtype() == expected_dtype
@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
    """Setting config.precision to "float32" yields torch.float32 for every configured device.

    The per-device dtype in the parametrized pair is deliberately ignored:
    the explicit precision setting is expected to win on a mocked CUDA machine.

    The config object's device and precision are restored afterwards. The
    dtype tests in this module that expect float16 rely on precision not being
    pinned to "float32", so leaving the override in place would make them
    depend on test execution order.
    """
    device_name, _ = device_dtype_pair
    with (
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        config = get_config()
        saved_device, saved_precision = config.device, config.precision
        try:
            config.device = device_name
            config.precision = "float32"
            assert TorchDevice.choose_torch_dtype() == torch.float32
        finally:
            # Undo the override even when the assertion fails.
            config.device = saved_device
            config.precision = saved_precision
def test_normalize():
    """Check TorchDevice.normalize() for CUDA, MPS and CPU device names.

    The original assertions had the form ``assert (a == b if cond else c)``.
    A conditional expression binds more loosely than ``==``, so that parses as
    ``(a == b) if cond else c``. Without CUDA each assertion merely evaluated
    the always-truthy ``torch.device("cuda")`` and verified nothing. The
    expected value is now selected first and then compared.
    """
    cuda_available = torch.cuda.is_available()

    # A bare "cuda" is expected to gain index 0 when CUDA is usable and to be
    # left alone otherwise (the expectation spelled out by the original test).
    bare_cuda_expected = torch.device("cuda:0") if cuda_available else torch.device("cuda")
    assert TorchDevice.normalize("cuda") == bare_cuda_expected

    # Names with an explicit index were only ever effectively compared on CUDA
    # hosts, so that guard is kept, now stated explicitly.
    # NOTE(review): the original else-branch named torch.device("cuda") as the
    # non-CUDA expectation for these too. Whether normalize() keeps or drops an
    # explicit index without CUDA is not visible from this file; confirm
    # against TorchDevice.normalize() before extending these checks.
    if cuda_available:
        assert TorchDevice.normalize("cuda:0") == torch.device("cuda:0")
        assert TorchDevice.normalize("cuda:1") == torch.device("cuda:1")

    assert TorchDevice.normalize("mps") == torch.device("mps")
    assert TorchDevice.normalize("cpu") == torch.device("cpu")
@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
    """The legacy choose_torch_device() emits a deprecation warning and returns the configured device."""
    get_config().device = device_name

    # deprecated_call() fails the test unless a deprecation warning is raised.
    with pytest.deprecated_call():
        legacy_choice = choose_torch_device()

    assert legacy_choice == torch.device(device_name)
@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
    """Legacy choose_torch_device() + torch_dtype() on a machine with no accelerator available."""
    requested_device, expected_dtype = device_dtype_pair

    # Same mocked environment as the non-legacy CPU test, plus a stubbed GPU name.
    cuda_absent = patch("torch.cuda.is_available", return_value=False)
    mps_absent = patch("torch.backends.mps.is_available", return_value=False)
    gpu_name = patch("torch.cuda.get_device_name", return_value="RTX9090")

    with cuda_absent, mps_absent, gpu_name:
        get_config().device = requested_device
        # The legacy helpers are called inside deprecated_call(), which
        # requires a deprecation warning to be raised within the block.
        with pytest.deprecated_call():
            legacy_device = choose_torch_device()
            legacy_dtype = torch_dtype(legacy_device)
        assert legacy_dtype == expected_dtype
def test_legacy_precision_name():
    """With precision set to "auto", the legacy choose_precision() names float16 for cuda/mps and float32 for cpu."""
    get_config().precision = "auto"

    # (device type, precision name expected from choose_precision), checked in this order.
    expectations = (
        ("cuda", "float16"),
        ("mps", "float16"),
        ("cpu", "float32"),
    )

    with (
        pytest.deprecated_call(),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        for device_type, precision_name in expectations:
            assert choose_precision(torch.device(device_type)) == precision_name