tests: update config tests

- Add patched rootdir fixture to all config tests. I think this isn't strictly necessary but it does ensure that any config tests that need to write files don't accidentally write to user data locations.
- Be more careful when calling `get_config()` in the tests, by clearing the LRU cache before and after. This ensures a test doesn't reference the singleton config created by a previously run test.
- Add test for env var parsing.
- Add test for config writing in the context of `get_config()`. This is effectively a mini e2e test for the config lifecycle.
This commit is contained in:
psychedelicious 2024-03-21 12:00:40 +11:00
parent f538ed54fb
commit 842b57e57c

View File

@ -6,7 +6,13 @@ import pytest
from omegaconf import OmegaConf
from pydantic import ValidationError
from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config, load_and_migrate_config
from invokeai.app.services.config.config_default import (
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
get_config,
load_and_migrate_config,
)
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
v4_config = """
schema_version: 4.0.0
@ -59,14 +65,14 @@ def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
def test_path_resolution_root_not_set():
def test_path_resolution_root_not_set(patch_rootdir: None):
"""Test path resolutions when the root is not explicitly set."""
config = InvokeAIAppConfig()
expected_root = InvokeAIAppConfig.find_root()
assert config.root_path == expected_root
def test_read_config_from_file(tmp_path: Path):
def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v4_config)
@ -76,7 +82,7 @@ def test_read_config_from_file(tmp_path: Path):
assert config.port == 8080
def test_migrate_v3_config_from_file(tmp_path: Path):
def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config)
@ -92,7 +98,7 @@ def test_migrate_v3_config_from_file(tmp_path: Path):
assert not hasattr(config, "esrgan")
def test_migrate_v3_backup(tmp_path: Path):
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
"""Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config)
@ -102,7 +108,7 @@ def test_migrate_v3_backup(tmp_path: Path):
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
def test_failed_migrate_backup(tmp_path: Path):
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
"""Test the failed migration of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config_with_bad_values)
@ -115,7 +121,7 @@ def test_failed_migrate_backup(tmp_path: Path):
assert temp_config_file.read_text() == v3_config_with_bad_values
def test_bails_on_invalid_config(tmp_path: Path):
def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(invalid_config)
@ -124,7 +130,7 @@ def test_bails_on_invalid_config(tmp_path: Path):
load_and_migrate_config(temp_config_file)
def test_bails_on_config_with_unsupported_version(tmp_path: Path):
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(invalid_v5_config)
@ -133,7 +139,7 @@ def test_bails_on_config_with_unsupported_version(tmp_path: Path):
load_and_migrate_config(temp_config_file)
def test_write_config_to_file():
def test_write_config_to_file(patch_rootdir: None):
"""Test writing configuration to a file, checking for correct output."""
with TemporaryDirectory() as tmpdir:
temp_config_path = Path(tmpdir) / "invokeai.yaml"
@ -148,7 +154,7 @@ def test_write_config_to_file():
assert "port: 8080" in content
def test_update_config_with_dict():
def test_update_config_with_dict(patch_rootdir: None):
"""Test updating the config with a dictionary."""
config = InvokeAIAppConfig()
update_dict = {"host": "10.10.10.10", "port": 6060}
@ -157,7 +163,7 @@ def test_update_config_with_dict():
assert config.port == 6060
def test_update_config_with_object():
def test_update_config_with_object(patch_rootdir: None):
"""Test updating the config with another config object."""
config = InvokeAIAppConfig()
new_config = InvokeAIAppConfig(host="10.10.10.10", port=6060)
@ -166,7 +172,7 @@ def test_update_config_with_object():
assert config.port == 6060
def test_set_and_resolve_paths():
def test_set_and_resolve_paths(patch_rootdir: None):
"""Test setting root and resolving paths based on it."""
with TemporaryDirectory() as tmpdir:
config = InvokeAIAppConfig()
@ -175,11 +181,62 @@ def test_set_and_resolve_paths():
assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db"
def test_singleton_behavior():
def test_singleton_behavior(patch_rootdir: None):
"""Test that get_config always returns the same instance."""
get_config.cache_clear()
config1 = get_config()
config2 = get_config()
assert config1 is config2
get_config.cache_clear()
def test_default_config(patch_rootdir: None):
"""Test that the default config is as expected."""
config = DefaultInvokeAIAppConfig()
assert config.host == "127.0.0.1"
def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
"""Test that environment variables are merged into the config"""
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4")
monkeypatch.setenv("INVOKEAI_PORT", "1234")
config = InvokeAIAppConfig()
assert config.host == "1.2.3.4"
assert config.port == 1234
assert config.root_path == tmp_path
def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
"""Test that get_config writes the appropriate files to disk"""
# Trick the config into thinking it has already parsed args - this triggers the writing of the config file
InvokeAIArgs.did_parse = True
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4")
get_config.cache_clear()
config = get_config()
get_config.cache_clear()
config_file_path = tmp_path / "invokeai.yaml"
example_file_path = config_file_path.with_suffix(".example.yaml")
assert config.config_file_path == config_file_path
assert config_file_path.exists()
assert example_file_path.exists()
# The example file should have the default values
example_file_content = example_file_path.read_text()
assert "host: 127.0.0.1" in example_file_content
assert "port: 9090" in example_file_content
# It should also have the `remote_api_tokens` key
assert "remote_api_tokens" in example_file_content
# Neither env vars nor default values should be written to the config file
config_file_content = config_file_path.read_text()
assert "host" not in config_file_content
# Undo our change to the singleton class
InvokeAIArgs.did_parse = False
@pytest.mark.xfail(
@ -212,7 +269,9 @@ def test_deny_nodes(patch_rootdir):
"""
)
# must parse config before importing Graph, so its nodes union uses the config
get_config.cache_clear()
conf = get_config()
get_config.cache_clear()
conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])
from invokeai.app.services.shared.graph import Graph