mirror of https://github.com/invoke-ai/InvokeAI

fix merge issues; likely nonfunctional
@@ -87,9 +87,11 @@ def test_rename(
     key = mm2_installer.install_path(embedding_file)
     model_record = store.get_model(key)
     assert model_record.path.endswith("sd-1/embedding/test_embedding.safetensors")
-    store.update_model(key, ModelRecordChanges(name="new_name.safetensors", base=BaseModelType("sd-2")))
+    store.update_model(key, ModelRecordChanges(name="new model name", base=BaseModelType("sd-2")))
     new_model_record = mm2_installer.sync_model_path(key)
-    assert new_model_record.path.endswith("sd-2/embedding/new_name.safetensors")
+    # Renaming the model record shouldn't rename the file
+    assert new_model_record.name == "new model name"
+    assert new_model_record.path.endswith("sd-2/embedding/test_embedding.safetensors")


 @pytest.mark.parametrize(
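The new assertions pin down that update_model() is a pure record edit, while sync_model_path() re-homes the file by base/type without touching the filename. A minimal sketch of that path rule, under stated assumptions (expected_synced_path is a hypothetical helper for illustration, not installer code):

    from pathlib import Path

    # Assumed rule: a synced model lives at <models_dir>/<base>/<type>/<original filename>,
    # so renaming the record changes the name field but not the on-disk file.
    def expected_synced_path(models_dir: Path, base: str, model_type: str, original_file: Path) -> Path:
        return models_dir / base / model_type / original_file.name

    assert expected_synced_path(
        Path("/models"), "sd-2", "embedding", Path("test_embedding.safetensors")
    ) == Path("/models/sd-2/embedding/test_embedding.safetensors")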
@@ -1,8 +1,8 @@
 import pytest
 import torch

-from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
 from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
+from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model

@@ -77,7 +77,7 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
     ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)

     cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
-    ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
+    ip_adapter_unet_patcher = UNetAttentionPatcher([ip_adapter])
     with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
         output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample

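The change above is a straight rename: UNetPatcher has become UNetAttentionPatcher, and apply_ip_adapter_attention() is still used as a context manager that patches the UNet only for the duration of the block. A toy sketch of that patch-and-restore contract (illustrative only; the real patcher swaps the UNet's attention processors rather than a single attribute):

    from contextlib import contextmanager

    @contextmanager
    def apply_patch(obj, attr, patched_value):
        # Swap in the patched value, then restore the original even if the body raises.
        original = getattr(obj, attr)
        setattr(obj, attr, patched_value)
        try:
            yield obj
        finally:
            setattr(obj, attr, original)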
tests/backend/util/test_devices.py (new file, 132 lines)
@@ -0,0 +1,132 @@
"""
Test abstract device class.
"""

from unittest.mock import patch

import pytest
import torch

from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype

devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]
device_types_cuda = [("cpu", torch.float32), ("cuda:0", torch.float16), ("mps", torch.float32)]
device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float16)]


@pytest.mark.parametrize("device_name", devices)
def test_device_choice(device_name):
    config = get_config()
    config.device = device_name
    torch_device = TorchDevice.choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_cuda(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_mps)
def test_device_dtype_mps(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=True),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == dtype


@pytest.mark.parametrize("device_dtype_pair", device_types_cuda)
def test_device_dtype_override(device_dtype_pair):
    with (
        patch("torch.cuda.get_device_name", return_value="RTX4070"),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=False),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        config.precision = "float32"
        torch_dtype = TorchDevice.choose_torch_dtype()
        assert torch_dtype == torch.float32


def test_normalize():
    assert (
        TorchDevice.normalize("cuda") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert (
        TorchDevice.normalize("cuda:0") == torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert (
        TorchDevice.normalize("cuda:1") == torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cuda")
    )
    assert TorchDevice.normalize("mps") == torch.device("mps")
    assert TorchDevice.normalize("cpu") == torch.device("cpu")


@pytest.mark.parametrize("device_name", devices)
def test_legacy_device_choice(device_name):
    config = get_config()
    config.device = device_name
    with pytest.deprecated_call():
        torch_device = choose_torch_device()
    assert torch_device == torch.device(device_name)


@pytest.mark.parametrize("device_dtype_pair", device_types_cpu)
def test_legacy_device_dtype_cpu(device_dtype_pair):
    with (
        patch("torch.cuda.is_available", return_value=False),
        patch("torch.backends.mps.is_available", return_value=False),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        device_name, dtype = device_dtype_pair
        config = get_config()
        config.device = device_name
        with pytest.deprecated_call():
            torch_device = choose_torch_device()
            returned_dtype = torch_dtype(torch_device)
        assert returned_dtype == dtype


def test_legacy_precision_name():
    config = get_config()
    config.precision = "auto"
    with (
        pytest.deprecated_call(),
        patch("torch.cuda.is_available", return_value=True),
        patch("torch.backends.mps.is_available", return_value=True),
        patch("torch.cuda.get_device_name", return_value="RTX9090"),
    ):
        assert "float16" == choose_precision(torch.device("cuda"))
        assert "float16" == choose_precision(torch.device("mps"))
        assert "float32" == choose_precision(torch.device("cpu"))
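Taken together, the three mocked-availability cases plus the override test encode a single dtype rule. A sketch of that rule as the tests pin it down (an assumed reading, not TorchDevice's actual implementation):

    import torch

    def choose_dtype_sketch(device: torch.device, precision: str = "auto") -> torch.dtype:
        if precision == "float32":
            return torch.float32  # explicit config.precision wins (see the override test)
        if device.type in ("cuda", "mps"):
            return torch.float16  # half precision on GPU backends
        return torch.float32      # full precision on CPU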
tests/backend/util/test_mask.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import pytest
import torch

from invokeai.backend.util.mask import to_standard_float_mask


def test_to_standard_float_mask_wrong_ndim():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((1, 1, 5, 10)), out_dtype=torch.float32)


def test_to_standard_float_mask_wrong_shape():
    with pytest.raises(ValueError):
        to_standard_float_mask(mask=torch.zeros((2, 5, 10)), out_dtype=torch.float32)


def check_mask_result(mask: torch.Tensor, expected_mask: torch.Tensor):
    """Helper function to check the result of `to_standard_float_mask()`."""
    assert mask.shape == expected_mask.shape
    assert mask.dtype == expected_mask.dtype
    assert torch.allclose(mask, expected_mask)


def test_to_standard_float_mask_ndim_2():
    """Test the case where the input mask has shape (h, w)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 1.0
    mask[1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)


def test_to_standard_float_mask_ndim_3():
    """Test the case where the input mask has shape (1, h, w)."""
    mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    mask[0, 0, 0] = 1.0
    mask[0, 1, 1] = 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=torch.float32)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=torch.float32)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_bool_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has dtype bool."""
    mask = torch.zeros((3, 2), dtype=torch.bool)
    mask[0, 0] = True
    mask[1, 1] = True

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 0] = 1.0
    expected_mask[0, 1, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)


@pytest.mark.parametrize(
    "out_dtype",
    [torch.float32, torch.float16],
)
def test_to_standard_float_mask_float_to_float(out_dtype: torch.dtype):
    """Test the case where the input mask has type float (but not all values are 0.0 or 1.0)."""
    mask = torch.zeros((3, 2), dtype=torch.float32)
    mask[0, 0] = 0.1  # Should be converted to 0.0
    mask[0, 1] = 0.9  # Should be converted to 1.0

    expected_mask = torch.zeros((1, 3, 2), dtype=out_dtype)
    expected_mask[0, 0, 1] = 1.0

    new_mask = to_standard_float_mask(mask=mask, out_dtype=out_dtype)

    check_mask_result(mask=new_mask, expected_mask=expected_mask)
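The contract these tests pin down: accept masks of shape (h, w) or (1, h, w), reject anything else, convert bool to float, and binarize non-binary float values. A sketch under those assumptions (the 0.5 threshold is inferred from the 0.1/0.9 cases, not confirmed against the real implementation):

    import torch

    def to_standard_float_mask_sketch(mask: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)  # (h, w) -> (1, h, w)
        if mask.ndim != 3 or mask.shape[0] != 1:
            raise ValueError("Expected mask of shape (h, w) or (1, h, w).")
        if mask.dtype == torch.bool:
            return mask.to(out_dtype)
        return (mask > 0.5).to(out_dtype)  # binarize float masks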
@@ -97,6 +97,32 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
     assert not hasattr(config, "esrgan")


+@pytest.mark.parametrize(
+    "legacy_conf_dir,expected_value,expected_is_set",
+    [
+        # not set, expected value is the default value
+        ("configs/stable-diffusion", Path("configs"), False),
+        # not set, expected value is the default value
+        ("configs\\stable-diffusion", Path("configs"), False),
+        # set, best-effort resolution of the path
+        ("partial_custom_path/stable-diffusion", Path("partial_custom_path"), True),
+        # set, exact path
+        ("full/custom/path", Path("full/custom/path"), True),
+    ],
+)
+def test_migrate_v3_legacy_conf_dir_defaults(
+    tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
+):
+    """Test reading configuration from a file."""
+    config_content = f"InvokeAI:\n  Paths:\n    legacy_conf_dir: {legacy_conf_dir}"
+    temp_config_file = tmp_path / "temp_invokeai.yaml"
+    temp_config_file.write_text(config_content)
+
+    config = load_and_migrate_config(temp_config_file)
+    assert config.legacy_conf_dir == expected_value
+    assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
+
+
 def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
     """Test the backup of the config file."""
     temp_config_file = tmp_path / "temp_invokeai.yaml"
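All four parametrized cases are consistent with one migration rule for legacy_conf_dir: normalize path separators, and if the value ends in the old default "stable-diffusion" directory, keep only its parent. A sketch of that rule (assumed logic; load_and_migrate_config may implement it differently):

    from pathlib import Path

    def migrate_legacy_conf_dir_sketch(value: str) -> Path:
        p = Path(value.replace("\\", "/"))  # tolerate Windows-style separators
        if p.name == "stable-diffusion":
            return p.parent  # "configs/stable-diffusion" -> "configs", etc.
        return p  # any other custom path is taken verbatim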
@@ -250,6 +250,32 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
         db.conn.close()


+def test_migrator_backs_up_db(logger: Logger) -> None:
+    with TemporaryDirectory() as tempdir:
+        original_db_path = Path(tempdir) / "invokeai.db"
+        db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
+        # Write some data to the db to test for successful backup
+        temp_cursor = db.conn.cursor()
+        temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
+        db.conn.commit()
+        # Set up the migrator
+        migrator = SqliteMigrator(db=db)
+        migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
+        for migration in migrations:
+            migrator.register_migration(migration)
+        migrator.run_migrations()
+        # Must manually close else we get an error on Windows
+        db.conn.close()
+        assert original_db_path.exists()
+        # We should have a backup file when we migrated a file db
+        assert migrator._backup_path
+        # Check that the test table exists as a proxy for successful backup
+        with closing(sqlite3.connect(migrator._backup_path)) as backup_db_conn:
+            backup_db_cursor = backup_db_conn.cursor()
+            backup_db_cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test';")
+            assert backup_db_cursor.fetchone() is not None
+
+
 def test_migrator_makes_no_changes_on_failed_migration(
     migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
 ) -> None:
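create_migrate() is referenced above but defined earlier in the test module. A plausible sketch (the cursor-based callback signature is assumed) is a factory returning no-op callbacks, since this test only exercises version bookkeeping and the backup path:

    import sqlite3

    def create_migrate(i: int):
        def migrate(cursor: sqlite3.Cursor) -> None:
            pass  # no schema changes; only the from/to version bump matters here
        return migrate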