mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added more unit tests
This commit is contained in:
parent
eaadc55c7d
commit
763a2e2632
@ -8,6 +8,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.app.services.config import get_config
|
from invokeai.app.services.config import get_config
|
||||||
|
from invokeai.backend.model_manager.load import ModelCache
|
||||||
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
|
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
|
||||||
|
|
||||||
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
|
devices = ["cpu", "cuda:0", "cuda:1", "mps"]
|
||||||
@ -130,3 +131,32 @@ def test_legacy_precision_name():
|
|||||||
assert "float16" == choose_precision(torch.device("cuda"))
|
assert "float16" == choose_precision(torch.device("cuda"))
|
||||||
assert "float16" == choose_precision(torch.device("mps"))
|
assert "float16" == choose_precision(torch.device("mps"))
|
||||||
assert "float32" == choose_precision(torch.device("cpu"))
|
assert "float32" == choose_precision(torch.device("cpu"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_device_support_1():
|
||||||
|
config = get_config()
|
||||||
|
config.devices = ["cuda:0", "cuda:1"]
|
||||||
|
assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_device_support_2():
|
||||||
|
config = get_config()
|
||||||
|
config.devices = None
|
||||||
|
with (
|
||||||
|
patch("torch.cuda.device_count", return_value=3),
|
||||||
|
patch("torch.cuda.is_available", return_value=True),
|
||||||
|
):
|
||||||
|
assert TorchDevice.execution_devices() == {
|
||||||
|
torch.device("cuda:0"),
|
||||||
|
torch.device("cuda:1"),
|
||||||
|
torch.device("cuda:2"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_device_support_3():
|
||||||
|
config = get_config()
|
||||||
|
config.devices = ["cuda:0", "cuda:1"]
|
||||||
|
cache = ModelCache()
|
||||||
|
with cache.reserve_execution_device() as gpu:
|
||||||
|
assert gpu in [torch.device(x) for x in config.devices]
|
||||||
|
assert TorchDevice.choose_torch_device() == gpu
|
||||||
|
Loading…
Reference in New Issue
Block a user