mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
842b57e57c
- Add patched rootdir fixture to all config tests. I think this isn't strictly necessary but it does ensure that any config tests that need to write files don't accidentally write to user data locations. - Be more careful when calling `get_config()` in the tests, by clearing the LRU cache before and after. This ensures a test doesn't reference the singleton config created by a previously run test. - Add test for env var parsing. - Add test for config writing in the context of `get_config()`. This is effectively a mini e2e test for the config lifecycle.
297 lines
10 KiB
Python
297 lines
10 KiB
Python
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from omegaconf import OmegaConf
|
|
from pydantic import ValidationError
|
|
|
|
from invokeai.app.services.config.config_default import (
|
|
DefaultInvokeAIAppConfig,
|
|
InvokeAIAppConfig,
|
|
get_config,
|
|
load_and_migrate_config,
|
|
)
|
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
|
|
|
# Minimal, valid v4-schema config: exercised by the plain "load from file" test.
v4_config = """
schema_version: 4.0.0

host: "192.168.1.1"
port: 8080
"""

# Same payload as `v4_config` but tagged with schema version 5.0.0; the loader test
# expects this to be rejected with an "Invalid schema version" RuntimeError.
invalid_v5_config = """
schema_version: 5.0.0

host: "192.168.1.1"
port: 8080
"""


# Legacy v3 layout: a top-level `InvokeAI` key with settings grouped by category.
# The nesting matters - the migration tests read `host`/`port`, `outdir`, `conf_path`
# and the cache sizes back out of the migrated config.
v3_config = """
InvokeAI:
  Web Server:
    host: 192.168.1.1
    port: 8080
  Features:
    esrgan: true
    internet_available: true
    log_tokenization: false
    patchmatch: true
    ignore_missing_core_models: false
  Paths:
    outdir: /some/outputs/dir
    conf_path: /custom/models.yaml
  Model Cache:
    max_cache_size: 100
    max_vram_cache_size: 50
"""

# v3 layout whose `port` is not a number; the failed-migration test expects a RuntimeError.
v3_config_with_bad_values = """
InvokeAI:
  Web Server:
    port: "ice cream"
"""

# Not a mapping at all; the loader test expects an AssertionError for this input.
invalid_config = """
i like turtles
"""
|
|
|
|
|
|
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
    """Point the `INVOKEAI_ROOT` env var at pytest's per-test temporary directory.

    Possibly overkill: the current tests don't need the root dir to exist.
    """
    root_override = str(tmp_path)
    monkeypatch.setenv("INVOKEAI_ROOT", root_override)
|
|
|
|
|
|
def test_path_resolution_root_not_set(patch_rootdir: None):
    """A config built without an explicit root reports the root that `find_root()` returns."""
    cfg = InvokeAIAppConfig()
    root_from_lookup = InvokeAIAppConfig.find_root()
    assert cfg.root_path == root_from_lookup
|
|
|
|
|
|
def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
    """Write the v4 fixture to disk, load it, and check the parsed host and port."""
    yaml_path = tmp_path / "temp_invokeai.yaml"
    yaml_path.write_text(v4_config)

    loaded = load_and_migrate_config(yaml_path)

    assert loaded.host == "192.168.1.1"
    assert loaded.port == 8080
|
|
|
|
|
|
def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
    """Load a v3-format file and check its values land on the migrated config's attributes."""
    yaml_path = tmp_path / "temp_invokeai.yaml"
    yaml_path.write_text(v3_config)

    migrated = load_and_migrate_config(yaml_path)

    assert migrated.outputs_dir == Path("/some/outputs/dir")
    assert migrated.host == "192.168.1.1"
    assert migrated.port == 8080
    assert migrated.ram == 100
    assert migrated.vram == 50
    assert migrated.legacy_models_yaml_path == Path("/custom/models.yaml")
    # `esrgan` is present in the v3 fixture but must not survive onto the migrated config
    assert not hasattr(migrated, "esrgan")
|
|
|
|
|
|
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
    """Loading a v3 file leaves a `.yaml.bak` copy holding the original text."""
    yaml_path = tmp_path / "temp_invokeai.yaml"
    yaml_path.write_text(v3_config)
    backup_path = yaml_path.with_suffix(".yaml.bak")

    load_and_migrate_config(yaml_path)

    assert backup_path.exists()
    assert backup_path.read_text() == v3_config
|
|
|
|
|
|
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
    """When loading a bad v3 file raises, both the backup and the original keep the input text."""
    yaml_path = tmp_path / "temp_invokeai.yaml"
    yaml_path.write_text(v3_config_with_bad_values)
    backup_path = yaml_path.with_suffix(".yaml.bak")

    with pytest.raises(RuntimeError):
        load_and_migrate_config(yaml_path)

    # Backup first, then the original file: each must exist and be unmodified.
    for candidate in (backup_path, yaml_path):
        assert candidate.exists()
        assert candidate.read_text() == v3_config_with_bad_values
|
|
|
|
|
|
def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
    """Loading a file whose content is not a config mapping raises AssertionError."""
    yaml_path = tmp_path / "temp_invokeai.yaml"
    yaml_path.write_text(invalid_config)

    with pytest.raises(AssertionError):
        load_and_migrate_config(yaml_path)
|
|
|
|
|
|
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None):
    """Loading a file tagged with schema version 5.0.0 raises an "Invalid schema version" error."""
    yaml_path = tmp_path / "temp_invokeai.yaml"
    yaml_path.write_text(invalid_v5_config)

    with pytest.raises(RuntimeError, match="Invalid schema version"):
        load_and_migrate_config(yaml_path)
|
|
|
|
|
|
def test_write_config_to_file(patch_rootdir: None):
    """Write a config with a custom host/port and inspect the resulting YAML text."""
    with TemporaryDirectory() as scratch_dir:
        target = Path(scratch_dir) / "invokeai.yaml"
        InvokeAIAppConfig(host="192.168.1.1", port=8080).write_file(target)

        # Read the file back while the temporary directory still exists
        written = target.read_text()

    # `pil_compress_level` was left at its default, so it should not be in the file
    assert "pil_compress_level" not in written
    assert "host: 192.168.1.1" in written
    assert "port: 8080" in written
|
|
|
|
|
|
def test_update_config_with_dict(patch_rootdir: None):
    """`update_config` applies the values from a plain dictionary."""
    cfg = InvokeAIAppConfig()
    cfg.update_config({"host": "10.10.10.10", "port": 6060})

    assert cfg.host == "10.10.10.10"
    assert cfg.port == 6060
|
|
|
|
|
|
def test_update_config_with_object(patch_rootdir: None):
    """`update_config` applies the values from another config instance."""
    target = InvokeAIAppConfig()
    overrides = InvokeAIAppConfig(host="10.10.10.10", port=6060)

    target.update_config(overrides)

    assert target.host == "10.10.10.10"
    assert target.port == 6060
|
|
|
|
|
|
def test_set_and_resolve_paths(patch_rootdir: None):
    """After overriding `_root`, the models and database paths are derived from the resolved root."""
    with TemporaryDirectory() as scratch_dir:
        scratch_root = Path(scratch_dir)
        cfg = InvokeAIAppConfig()
        cfg._root = scratch_root

        resolved_root = scratch_root.resolve()
        assert cfg.models_path == resolved_root / "models"
        assert cfg.db_path == resolved_root / "databases" / "invokeai.db"
|
|
|
|
|
|
def test_singleton_behavior(patch_rootdir: None):
    """Test that get_config always returns the same instance."""
    # Start from an empty cache so we don't pick up a config built by an earlier test.
    get_config.cache_clear()
    try:
        config1 = get_config()
        config2 = get_config()
        assert config1 is config2
    finally:
        # Clear even when the test fails, so the cached config never leaks into later tests.
        get_config.cache_clear()
|
|
|
|
|
|
def test_default_config(patch_rootdir: None):
    """The default config class reports `127.0.0.1` as its host."""
    defaults = DefaultInvokeAIAppConfig()
    assert defaults.host == "127.0.0.1"
|
|
|
|
|
|
def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """`INVOKEAI_*` environment variables show up as values on a freshly built config."""
    env_overrides = {
        "INVOKEAI_ROOT": str(tmp_path),
        "INVOKEAI_HOST": "1.2.3.4",
        "INVOKEAI_PORT": "1234",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)

    cfg = InvokeAIAppConfig()

    assert cfg.host == "1.2.3.4"
    # The port env var is a string; the config exposes it as an int
    assert cfg.port == 1234
    assert cfg.root_path == tmp_path
|
|
|
|
|
|
def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
    """Test that get_config writes the appropriate files to disk"""
    # Trick the config into thinking it has already parsed args - this triggers the writing of the config file.
    # Going through monkeypatch means the previous value is restored on teardown even if an assertion
    # below fails, so the flag on the shared class never leaks into other tests.
    monkeypatch.setattr(InvokeAIArgs, "did_parse", True)

    monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
    monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4")

    # Clear the cache before and after, so this test neither reuses nor leaves behind a cached config.
    get_config.cache_clear()
    try:
        config = get_config()
    finally:
        get_config.cache_clear()

    config_file_path = tmp_path / "invokeai.yaml"
    example_file_path = config_file_path.with_suffix(".example.yaml")
    assert config.config_file_path == config_file_path
    assert config_file_path.exists()
    assert example_file_path.exists()

    # The example file should have the default values
    example_file_content = example_file_path.read_text()
    assert "host: 127.0.0.1" in example_file_content
    assert "port: 9090" in example_file_content

    # It should also have the `remote_api_tokens` key
    assert "remote_api_tokens" in example_file_content

    # Neither env vars nor default values should be written to the config file
    config_file_content = config_file_path.read_text()
    assert "host" not in config_file_content
|
|
|
|
|
|
@pytest.mark.xfail(
    reason="""
    This test fails when run as part of the full test suite.

    This test needs to deny nodes from being included in the InvocationsUnion by providing
    an app configuration as a test fixture. Pytest executes all test files before running
    tests, so the app configuration is already initialized by the time this test runs, and
    the InvocationUnion is already created and the denied nodes are not omitted from it.

    This test passes when `test_config.py` is tested in isolation.

    Perhaps a solution would be to call `get_app_config().parse_args()` in
    other test files?
    """
)
def test_deny_nodes(patch_rootdir):
    """Check that a node listed under `deny_nodes` is rejected even if it is also allowed.

    Marked xfail: see the decorator's reason for why this only passes in isolation.
    """
    # Allow integer, string and float, but explicitly deny float
    allow_deny_nodes_conf = OmegaConf.create(
        """
        InvokeAI:
          Nodes:
            allow_nodes:
              - integer
              - string
              - float
            deny_nodes:
              - float
        """
    )
    # must parse config before importing Graph, so its nodes union uses the config
    get_config.cache_clear()
    try:
        conf = get_config()
    finally:
        # Always drop the cached config so it is not shared with other tests.
        get_config.cache_clear()
    conf.merge_from_file(conf=allow_deny_nodes_conf, argv=[])

    # Imported here on purpose: the import must happen after the config has been set up (see above).
    from invokeai.app.services.shared.graph import Graph

    # confirm graph validation succeeds for allowed nodes and fails when using the denied node
    Graph(nodes={"1": {"id": "1", "type": "integer"}})
    Graph(nodes={"1": {"id": "1", "type": "string"}})

    with pytest.raises(ValidationError):
        Graph(nodes={"1": {"id": "1", "type": "float"}})

    from invokeai.app.invocations.baseinvocation import BaseInvocation

    # confirm invocations union will not have denied nodes
    all_invocations = BaseInvocation.get_invocations()

    def registered_exactly_once(node_type: str) -> bool:
        """Return True when exactly one invocation has `node_type` as the default of its `type` field."""
        matches = [i for i in all_invocations if i.model_fields.get("type").default == node_type]
        return len(matches) == 1

    assert registered_exactly_once("integer")
    assert registered_exactly_once("string")
    assert not registered_exactly_once("float")
|