diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py
index 8e810e4367..f4faea5d98 100644
--- a/tests/backend/util/test_devices.py
+++ b/tests/backend/util/test_devices.py
@@ -8,6 +8,7 @@ import pytest
 import torch
 
 from invokeai.app.services.config import get_config
+from invokeai.backend.model_manager.load import ModelCache
 from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
 
 devices = ["cpu", "cuda:0", "cuda:1", "mps"]
@@ -130,3 +131,32 @@ def test_legacy_precision_name():
     assert "float16" == choose_precision(torch.device("cuda"))
     assert "float16" == choose_precision(torch.device("mps"))
     assert "float32" == choose_precision(torch.device("cpu"))
+
+
+def test_multi_device_support_1():
+    config = get_config()
+    config.devices = ["cuda:0", "cuda:1"]
+    assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
+
+
+def test_multi_device_support_2():
+    config = get_config()
+    config.devices = None
+    with (
+        patch("torch.cuda.device_count", return_value=3),
+        patch("torch.cuda.is_available", return_value=True),
+    ):
+        assert TorchDevice.execution_devices() == {
+            torch.device("cuda:0"),
+            torch.device("cuda:1"),
+            torch.device("cuda:2"),
+        }
+
+
+def test_multi_device_support_3():
+    config = get_config()
+    config.devices = ["cuda:0", "cuda:1"]
+    cache = ModelCache()
+    with cache.reserve_execution_device() as gpu:
+        assert gpu in [torch.device(x) for x in config.devices]
+        assert TorchDevice.choose_torch_device() == gpu
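
A minimal usage sketch of the reservation pattern the third test exercises, assuming only the APIs visible in the diff (`ModelCache.reserve_execution_device`, `TorchDevice.choose_torch_device`, and the `config.devices` setting); the helper name and the model/batch handling are illustrative, not part of this PR:

```python
# Sketch, not part of the diff: how a worker might reserve one of the
# configured execution devices for the duration of a job.
import torch

from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice


def run_on_free_gpu(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Reserve one configured execution device, run, and release on exit."""
    cache = ModelCache()
    # Per the tests above, reserve_execution_device() yields one device drawn
    # from config.devices (or from all visible CUDA devices when config.devices
    # is None) and holds it for the duration of the `with` block.
    with cache.reserve_execution_device() as gpu:
        # While reserved, TorchDevice.choose_torch_device() reports the same
        # device, so code that asks for "the" device stays consistent.
        assert TorchDevice.choose_torch_device() == gpu
        return model.to(gpu)(batch.to(gpu))
```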