tests(config): set root to a tmp dir if didn't parse args

This prevents tests from triggering config related parsing on your "live" root.
This commit is contained in:
psychedelicious 2024-05-14 18:02:22 +10:00
parent 6e40142a59
commit 8b76d112be
2 changed files with 19 additions and 22 deletions

View File

@ -9,6 +9,7 @@ import shutil
from copy import deepcopy from copy import deepcopy
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable from typing import Iterable
import yaml import yaml
@ -141,6 +142,8 @@ def get_config() -> InvokeAIAppConfig:
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not. # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc. # If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse: if not InvokeAIArgs.did_parse:
tmpdir = TemporaryDirectory()
config._root = Path(tmpdir.name)
return config return config
# Set CLI args # Set CLI args

View File

@ -1,7 +1,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Any, Generator from typing import Generator
import pytest import pytest
import yaml import yaml
@ -72,12 +72,6 @@ i like turtles
""" """
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
"""This may be overkill since the current tests don't need the root dir to exist"""
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
def test_config_migrator_registers_migrations() -> None: def test_config_migrator_registers_migrations() -> None:
"""Test that the config migrator registers migrations.""" """Test that the config migrator registers migrations."""
migrator = ConfigMigrator() migrator = ConfigMigrator()
@ -168,14 +162,14 @@ def test_config_migrator_runs_migrations() -> None:
assert migrated_config == {"schema_version": "5.0.0"} assert migrated_config == {"schema_version": "5.0.0"}
def test_path_resolution_root_not_set(patch_rootdir: None): def test_path_resolution_root_not_set():
"""Test path resolutions when the root is not explicitly set.""" """Test path resolutions when the root is not explicitly set."""
config = InvokeAIAppConfig() config = InvokeAIAppConfig()
expected_root = InvokeAIAppConfig.find_root() expected_root = InvokeAIAppConfig.find_root()
assert config.root_path == expected_root assert config.root_path == expected_root
def test_read_config_from_file(tmp_path: Path, patch_rootdir: None): def test_read_config_from_file(tmp_path: Path):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v4_config) temp_config_file.write_text(v4_config)
@ -185,7 +179,7 @@ def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
assert config.port == 8080 assert config.port == 8080
def test_migration_1_migrates_settings(tmp_path: Path, patch_rootdir: None): def test_migration_1_migrates_settings(tmp_path: Path):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config)) migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config))
config = InvokeAIAppConfig.model_validate(migrated_config_dict) config = InvokeAIAppConfig.model_validate(migrated_config_dict)
@ -223,7 +217,7 @@ def test_migration_1_handles_legacy_conf_dir_defaults(
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
def test_load_and_migrate_backs_up_file(tmp_path: Path, patch_rootdir: None): def test_load_and_migrate_backs_up_file(tmp_path: Path):
"""Test the backup of the config file.""" """Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config) temp_config_file.write_text(v3_config)
@ -241,7 +235,7 @@ def test_migration_2_migrates_settings():
assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): def test_load_and_migrate_failed_migrate_backup(tmp_path: Path):
"""Test the failed migration of the config file.""" """Test the failed migration of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config_with_bad_values) temp_config_file.write_text(v3_config_with_bad_values)
@ -254,7 +248,7 @@ def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: N
assert temp_config_file.read_text() == v3_config_with_bad_values assert temp_config_file.read_text() == v3_config_with_bad_values
def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(invalid_config) temp_config_file.write_text(invalid_config)
@ -264,7 +258,7 @@ def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir:
@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config]) @pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config])
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str): def test_bails_on_config_with_unsupported_version(tmp_path: Path, config_content: str):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(config_content) temp_config_file.write_text(config_content)
@ -274,7 +268,7 @@ def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir:
load_and_migrate_config(temp_config_file) load_and_migrate_config(temp_config_file)
def test_write_config_to_file(patch_rootdir: None): def test_write_config_to_file():
"""Test writing configuration to a file, checking for correct output.""" """Test writing configuration to a file, checking for correct output."""
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
temp_config_path = Path(tmpdir) / "invokeai.yaml" temp_config_path = Path(tmpdir) / "invokeai.yaml"
@ -289,7 +283,7 @@ def test_write_config_to_file(patch_rootdir: None):
assert "port: 8080" in content assert "port: 8080" in content
def test_update_config_with_dict(patch_rootdir: None): def test_update_config_with_dict():
"""Test updating the config with a dictionary.""" """Test updating the config with a dictionary."""
config = InvokeAIAppConfig() config = InvokeAIAppConfig()
update_dict = {"host": "10.10.10.10", "port": 6060} update_dict = {"host": "10.10.10.10", "port": 6060}
@ -298,7 +292,7 @@ def test_update_config_with_dict(patch_rootdir: None):
assert config.port == 6060 assert config.port == 6060
def test_update_config_with_object(patch_rootdir: None): def test_update_config_with_object():
"""Test updating the config with another config object.""" """Test updating the config with another config object."""
config = InvokeAIAppConfig() config = InvokeAIAppConfig()
new_config = InvokeAIAppConfig(host="10.10.10.10", port=6060) new_config = InvokeAIAppConfig(host="10.10.10.10", port=6060)
@ -307,7 +301,7 @@ def test_update_config_with_object(patch_rootdir: None):
assert config.port == 6060 assert config.port == 6060
def test_set_and_resolve_paths(patch_rootdir: None): def test_set_and_resolve_paths():
"""Test setting root and resolving paths based on it.""" """Test setting root and resolving paths based on it."""
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
config = InvokeAIAppConfig() config = InvokeAIAppConfig()
@ -316,7 +310,7 @@ def test_set_and_resolve_paths(patch_rootdir: None):
assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db" assert config.db_path == Path(tmpdir).resolve() / "databases" / "invokeai.db"
def test_singleton_behavior(patch_rootdir: None): def test_singleton_behavior():
"""Test that get_config always returns the same instance.""" """Test that get_config always returns the same instance."""
get_config.cache_clear() get_config.cache_clear()
config1 = get_config() config1 = get_config()
@ -325,13 +319,13 @@ def test_singleton_behavior(patch_rootdir: None):
get_config.cache_clear() get_config.cache_clear()
def test_default_config(patch_rootdir: None): def test_default_config():
"""Test that the default config is as expected.""" """Test that the default config is as expected."""
config = DefaultInvokeAIAppConfig() config = DefaultInvokeAIAppConfig()
assert config.host == "127.0.0.1" assert config.host == "127.0.0.1"
def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path): def test_env_vars(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
"""Test that environment variables are merged into the config""" """Test that environment variables are merged into the config"""
monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path)) monkeypatch.setenv("INVOKEAI_ROOT", str(tmp_path))
monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4") monkeypatch.setenv("INVOKEAI_HOST", "1.2.3.4")
@ -342,7 +336,7 @@ def test_env_vars(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path
assert config.root_path == tmp_path assert config.root_path == tmp_path
def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch, tmp_path: Path): def test_get_config_writing(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
"""Test that get_config writes the appropriate files to disk""" """Test that get_config writes the appropriate files to disk"""
# Trick the config into thinking it has already parsed args - this triggers the writing of the config file # Trick the config into thinking it has already parsed args - this triggers the writing of the config file
InvokeAIArgs.did_parse = True InvokeAIArgs.did_parse = True