mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tests: redo config tests
This commit is contained in:
parent
53c8f36029
commit
5606f4d627
@ -1,11 +1,49 @@
|
|||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
|
||||||
|
|
||||||
|
v4_config = """
|
||||||
|
meta:
|
||||||
|
schema_version: 4
|
||||||
|
|
||||||
|
host: "192.168.1.1"
|
||||||
|
port: 8080
|
||||||
|
"""
|
||||||
|
|
||||||
|
invalid_v5_config = """
|
||||||
|
meta:
|
||||||
|
schema_version: 5
|
||||||
|
|
||||||
|
host: "192.168.1.1"
|
||||||
|
port: 8080
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
v3_config = """
|
||||||
|
InvokeAI:
|
||||||
|
Web Server:
|
||||||
|
host: 192.168.1.1
|
||||||
|
port: 8080
|
||||||
|
Features:
|
||||||
|
esrgan: true
|
||||||
|
internet_available: true
|
||||||
|
log_tokenization: false
|
||||||
|
patchmatch: true
|
||||||
|
ignore_missing_core_models: false
|
||||||
|
Paths:
|
||||||
|
outdir: /some/outputs/dir
|
||||||
|
"""
|
||||||
|
|
||||||
|
invalid_config = """
|
||||||
|
i like turtles
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||||
@ -13,151 +51,101 @@ def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
|||||||
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
|
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
|
||||||
|
|
||||||
|
|
||||||
init1 = OmegaConf.create(
|
def test_path_resolution_root_not_set():
|
||||||
"""
|
"""Test path resolutions when the root is not explicitly set."""
|
||||||
InvokeAI:
|
config = InvokeAIAppConfig()
|
||||||
Features:
|
expected_root = InvokeAIAppConfig.find_root()
|
||||||
always_use_cpu: false
|
assert config.root_path == expected_root
|
||||||
Model Cache:
|
|
||||||
convert_cache: 5
|
|
||||||
Generation:
|
|
||||||
force_tiled_decode: false
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
init2 = OmegaConf.create(
|
|
||||||
"""
|
|
||||||
InvokeAI:
|
|
||||||
Features:
|
|
||||||
always_use_cpu: true
|
|
||||||
Model Cache:
|
|
||||||
convert_cache: 2
|
|
||||||
Generation:
|
|
||||||
force_tiled_decode: true
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
init3 = OmegaConf.create(
|
|
||||||
"""
|
|
||||||
InvokeAI:
|
|
||||||
Generation:
|
|
||||||
sequential_guidance: true
|
|
||||||
attention_type: xformers
|
|
||||||
attention_slice_size: 7
|
|
||||||
forced_tiled_decode: True
|
|
||||||
Device:
|
|
||||||
device: cpu
|
|
||||||
Model Cache:
|
|
||||||
ram: 1.25
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_use_init(patch_rootdir):
|
def test_read_config_from_file(tmp_path: Path):
|
||||||
# note that we explicitly set omegaconf dict and argv here
|
"""Test reading configuration from a file."""
|
||||||
# so that the values aren't read from ~invokeai/invokeai.yaml and
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
# sys.argv respectively.
|
temp_config_file.write_text(v4_config)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
conf1 = InvokeAIAppConfig.get_config()
|
config = load_and_migrate_config(temp_config_file)
|
||||||
assert conf1
|
assert config.host == "192.168.1.1"
|
||||||
conf1.parse_args(conf=init1, argv=[])
|
assert config.port == 8080
|
||||||
assert not conf1.force_tiled_decode
|
|
||||||
assert conf1.convert_cache == 5
|
|
||||||
assert not conf1.always_use_cpu
|
|
||||||
|
|
||||||
conf2 = InvokeAIAppConfig.get_config()
|
|
||||||
assert conf2
|
|
||||||
conf2.parse_args(conf=init2, argv=[])
|
|
||||||
assert conf2.force_tiled_decode
|
|
||||||
assert conf2.convert_cache == 2
|
|
||||||
assert not hasattr(conf2, "invalid_attribute")
|
|
||||||
|
|
||||||
|
|
||||||
def test_legacy():
|
def test_migrate_v3_config_from_file(tmp_path: Path):
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
"""Test reading configuration from a file."""
|
||||||
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
temp_config_file.write_text(v3_config)
|
||||||
|
|
||||||
conf = InvokeAIAppConfig.get_config()
|
config = load_and_migrate_config(temp_config_file)
|
||||||
assert conf
|
assert config.outputs_dir == Path("/some/outputs/dir")
|
||||||
conf.parse_args(conf=init3, argv=[])
|
assert config.host == "192.168.1.1"
|
||||||
assert conf.xformers_enabled
|
assert config.port == 8080
|
||||||
assert conf.device == "cpu"
|
# This should be stripped out
|
||||||
assert conf.use_cpu
|
assert not hasattr(config, "esrgan")
|
||||||
assert conf.ram == 1.25
|
|
||||||
assert conf.ram_cache_size == 1.25
|
|
||||||
|
|
||||||
|
|
||||||
def test_argv_override():
|
def test_bails_on_invalid_config(tmp_path: Path):
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
"""Test reading configuration from a file."""
|
||||||
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
temp_config_file.write_text(invalid_config)
|
||||||
|
|
||||||
conf = InvokeAIAppConfig.get_config()
|
with pytest.raises(AssertionError):
|
||||||
conf.parse_args(conf=init1, argv=["--always_use_cpu", "--max_cache=10"])
|
load_and_migrate_config(temp_config_file)
|
||||||
assert conf.always_use_cpu
|
|
||||||
assert conf.max_cache_size == 10
|
|
||||||
assert conf.outdir == Path("outputs") # this is the default
|
|
||||||
|
|
||||||
|
|
||||||
def test_env_override(patch_rootdir):
|
def test_bails_on_config_with_unsupported_version(tmp_path: Path):
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
"""Test reading configuration from a file."""
|
||||||
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
temp_config_file.write_text(invalid_v5_config)
|
||||||
|
|
||||||
# argv overrides
|
with pytest.raises(RuntimeError, match="Invalid schema version"):
|
||||||
conf = InvokeAIAppConfig()
|
load_and_migrate_config(temp_config_file)
|
||||||
conf.parse_args(conf=init1, argv=["--max_cache=10"])
|
|
||||||
assert conf.always_use_cpu is False
|
|
||||||
os.environ["INVOKEAI_always_use_cpu"] = "True"
|
|
||||||
conf.parse_args(conf=init1, argv=["--max_cache=10"])
|
|
||||||
assert conf.always_use_cpu is True
|
|
||||||
|
|
||||||
# environment variables should be case insensitive
|
|
||||||
os.environ["InvokeAI_Max_Cache_Size"] = "15"
|
|
||||||
conf = InvokeAIAppConfig()
|
|
||||||
conf.parse_args(conf=init1, argv=[])
|
|
||||||
assert conf.max_cache_size == 15
|
|
||||||
|
|
||||||
conf = InvokeAIAppConfig()
|
|
||||||
conf.parse_args(conf=init1, argv=["--no-always_use_cpu", "--max_cache=10"])
|
|
||||||
assert conf.always_use_cpu is False
|
|
||||||
assert conf.max_cache_size == 10
|
|
||||||
|
|
||||||
conf = InvokeAIAppConfig.get_config(max_cache_size=20)
|
|
||||||
conf.parse_args(conf=init1, argv=[])
|
|
||||||
assert conf.max_cache_size == 20
|
|
||||||
|
|
||||||
# make sure that prefix is respected
|
|
||||||
del os.environ["INVOKEAI_always_use_cpu"]
|
|
||||||
os.environ["always_use_cpu"] = "True"
|
|
||||||
conf.parse_args(conf=init1, argv=[])
|
|
||||||
assert conf.always_use_cpu is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_root_resists_cwd(patch_rootdir):
|
def test_write_config_to_file():
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
"""Test writing configuration to a file, checking for correct output."""
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
temp_config_path = Path(tmpdir) / "invokeai.yaml"
|
||||||
|
config = InvokeAIAppConfig(host="192.168.1.1", port=8080)
|
||||||
|
config.set_root(Path(tmpdir))
|
||||||
|
config.write_file(exclude_defaults=False)
|
||||||
|
|
||||||
previous = os.environ["INVOKEAI_ROOT"]
|
# Load the file and check contents
|
||||||
cwd = Path(os.getcwd()).resolve()
|
with open(temp_config_path, "r") as file:
|
||||||
|
content = file.read()
|
||||||
os.environ["INVOKEAI_ROOT"] = "."
|
assert "host: 192.168.1.1" in content
|
||||||
conf = InvokeAIAppConfig.get_config()
|
assert "port: 8080" in content
|
||||||
conf.parse_args([])
|
|
||||||
assert conf.root_path == cwd
|
|
||||||
|
|
||||||
os.chdir("..")
|
|
||||||
assert conf.root_path == cwd
|
|
||||||
os.environ["INVOKEAI_ROOT"] = previous
|
|
||||||
os.chdir(cwd)
|
|
||||||
|
|
||||||
|
|
||||||
def test_type_coercion(patch_rootdir):
|
def test_update_config_with_dict():
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
"""Test updating the config with a dictionary."""
|
||||||
|
config = InvokeAIAppConfig()
|
||||||
|
update_dict = {"host": "10.10.10.10", "port": 6060}
|
||||||
|
config.update_config(update_dict)
|
||||||
|
assert config.host == "10.10.10.10"
|
||||||
|
assert config.port == 6060
|
||||||
|
|
||||||
conf = InvokeAIAppConfig().get_config()
|
|
||||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
def test_update_config_with_object():
|
||||||
assert conf.root == Path("/tmp/foobar")
|
"""Test updating the config with another config object."""
|
||||||
assert isinstance(conf.root, Path)
|
config = InvokeAIAppConfig()
|
||||||
conf = InvokeAIAppConfig.get_config(root="/tmp/different")
|
new_config = InvokeAIAppConfig(host="10.10.10.10", port=6060)
|
||||||
conf.parse_args(argv=["--root=/tmp/foobar"])
|
config.update_config(new_config)
|
||||||
assert conf.root == Path("/tmp/different")
|
assert config.host == "10.10.10.10"
|
||||||
assert isinstance(conf.root, Path)
|
assert config.port == 6060
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_and_resolve_paths():
|
||||||
|
"""Test setting root and resolving paths based on it."""
|
||||||
|
with TemporaryDirectory() as tmpdir:
|
||||||
|
config = InvokeAIAppConfig()
|
||||||
|
config.set_root(Path(tmpdir))
|
||||||
|
assert config.models_path == Path(tmpdir) / "models"
|
||||||
|
assert config.db_path == Path(tmpdir) / "databases" / "invokeai.db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_singleton_behavior():
|
||||||
|
"""Test that get_config always returns the same instance."""
|
||||||
|
config1 = get_config()
|
||||||
|
config2 = get_config()
|
||||||
|
assert config1 is config2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(
|
@pytest.mark.xfail(
|
||||||
@ -171,13 +159,11 @@ def test_type_coercion(patch_rootdir):
|
|||||||
|
|
||||||
This test passes when `test_config.py` is tested in isolation.
|
This test passes when `test_config.py` is tested in isolation.
|
||||||
|
|
||||||
Perhaps a solution would be to call `InvokeAIAppConfig.get_config().parse_args()` in
|
Perhaps a solution would be to call `get_app_config().parse_args()` in
|
||||||
other test files?
|
other test files?
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
def test_deny_nodes(patch_rootdir):
|
def test_deny_nodes(patch_rootdir):
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
|
|
||||||
# Allow integer, string and float, but explicitly deny float
|
# Allow integer, string and float, but explicitly deny float
|
||||||
allow_deny_nodes_conf = OmegaConf.create(
|
allow_deny_nodes_conf = OmegaConf.create(
|
||||||
"""
|
"""
|
||||||
@ -192,8 +178,8 @@ def test_deny_nodes(patch_rootdir):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
# must parse config before importing Graph, so its nodes union uses the config
|
# must parse config before importing Graph, so its nodes union uses the config
|
||||||
conf = InvokeAIAppConfig().get_config()
|
conf = get_config()
|
||||||
conf.parse_args(conf=allow_deny_nodes_conf, argv=[])
|
conf.read_config(conf=allow_deny_nodes_conf, argv=[])
|
||||||
from invokeai.app.services.shared.graph import Graph
|
from invokeai.app.services.shared.graph import Graph
|
||||||
|
|
||||||
# confirm graph validation fails when using denied node
|
# confirm graph validation fails when using denied node
|
||||||
|
Loading…
Reference in New Issue
Block a user