fix merge issues; likely nonfunctional

Lincoln Stein
2024-04-15 21:16:21 -04:00
214 changed files with 4032 additions and 2058 deletions

View File

@@ -1,8 +1,8 @@
 import pytest
 import torch
-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model
@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample
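The only functional change in this hunk is the class name; the constructor and the context-manager API carry over unchanged. A minimal usage sketch, assuming UNetAttentionPatcher keeps the old UNetPatcher interface (it takes a list of IP-Adapters, and the patch is scoped to the context manager):

# Sketch only; assumes UNetAttentionPatcher mirrors the old UNetPatcher API.
patcher = UNetAttentionPatcher([ip_adapter])
with patcher.apply_ip_adapter_attention(unet):
    # Attention layers are patched only inside this block.
    output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample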

View File

@@ -0,0 +1,132 @@
"""
Test abstract device class.
"""

from unittest.mock import patch

import pytest
import torch

from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype

devices = ["cpu", "cuda:0", "cuda:1", "mps"]

device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]


@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
    config = get_config()
    config.device = device_name
    torch_device = TorchDevice.choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=True),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
    with (
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        config.precision = "float32"
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == torch.float32


def test_normalize():
    # Parenthesize the conditional expressions so each assert actually compares
    # against the expected device (otherwise the "else" branch makes the assert vacuous).
    assert TorchDevice.normalize("cuda") == (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert TorchDevice.normalize("cuda:0") == (
        torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert TorchDevice.normalize("cuda:1") == (
        torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert TorchDevice.normalize("mps") == torch.device("mps")
    assert TorchDevice.normalize("cpu") == torch.device("cpu")


@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
    config = get_config()
    config.device = device_name
    with pytest.deprecated_call():
        torch_device = choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        with pytest.deprecated_call():
            torch_device = choose_torch_device()
            returned_dtype = torch_dtype(torch_device)
        assert returned_dtype == dtype


def test_legacy_precision_name():
    config = get_config()
    config.precision = "auto"
    with (
        pytest.deprecated_call(),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        assert "float16" == choose_precision(torch.device("cuda"))
        assert "float16" == choose_precision(torch.device("mps"))
        assert "float32" == choose_precision(torch.device("cpu"))

View File

@@ -0,0 +1,88 @@
import pytest
import torch

from invokeai.backend.util.mask import to_standard_float_mask


def test_to_standard_float_mask_wrong_ndim():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32)


def test_to_standard_float_mask_wrong_shape():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32)


def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor):
    """Helper function to check the result of `to_standard_float_mask()`."""
    assert mask.shape == expected_mask.shape
    assert mask.dtype == expected_mask.dtype
    assert torch.allclose(mask, expected_mask)


def test_to_standard_float_mask_ndim_2():
    """Test the case where the input mask has shape (h, w)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 1.0
    mask[1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)


def test_to_standard_float_mask_ndim_3():
    """Test the case where the input mask has shape (1, h, w)."""
    mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    mask[0, 0, 0] = 1.0
    mask[0, 1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has dtype bool."""
    mask = torch.zeros((3, 2), dtype=torch.bool)
    mask[0, 0] = True
    mask[1, 1] = True

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has type float (but not all values are 0.0 or 1.0)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 0.1  # Should be converted to 0.0.
    mask[0, 1] = 0.9  # Should be converted to 1.0.

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)
    check_mask_result(mask=new_mask, expected_mask=expected_mask)
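These tests fully characterize the contract of to_standard_float_mask(): it accepts an (h, w) or (1, h, w) mask of bool or float dtype, raises ValueError for any other shape, binarizes float inputs (0.1 maps to 0.0 and 0.9 to 1.0), and returns a (1, h, w) tensor in out_dtype. A hypothetical reference implementation that satisfies exactly these tests (the 0.5 threshold is an assumption; the tests only pin down 0.1 and 0.9):

# Hypothetical sketch reconstructed from the tests; the shipped
# invokeai.backend.util.mask.to_standard_float_mask may differ.
import torch

def to_standard_float_mask_sketch(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    # Accept (h, w) or (1, h, w); reject every other shape.
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    elif mask.ndim != 3 or mask.shape[0] != 1:
        raise ValueError(f"Unsupported mask shape: {tuple(mask.shape)}")
    if mask.dtype == torch.bool:
        return mask.to(out_dtype)
    # Binarize float inputs; the 0.5 threshold is assumed, not pinned down by the tests.
    return (mask > 0.5).to(out_dtype)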