tests(config): test migrations directly, not via load_and_migrate_config

This commit is contained in:
psychedelicious 2024-05-14 17:16:50 +10:00
parent 00ccd73d53
commit 6e40142a59

View File

@ -4,6 +4,7 @@ from tempfile import TemporaryDirectory
from typing import Any, Generator
import pytest
import yaml
from packaging.version import Version
from pydantic import ValidationError
@ -14,6 +15,7 @@ from invokeai.app.services.config.config_default import (
InvokeAIAppConfig,
)
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
from invokeai.app.services.config.migrations import migrate_v300_to_v400, migrate_v400_to_v401
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -183,12 +185,10 @@ def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
assert config.port == 8080
def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
def test_migration_1_migrates_settings(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config)
config = load_and_migrate_config(temp_config_file)
migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config))
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert config.outputs_dir == Path("/some/outputs/dir")
assert config.host == "192.168.1.1"
assert config.port == 8080
@ -212,20 +212,18 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
("full/custom/path", Path("full/custom/path"), True),
],
)
def test_migrate_v3_legacy_conf_dir_defaults(
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
def test_migration_1_handles_legacy_conf_dir_defaults(
legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
):
"""Test reading configuration from a file."""
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(config_content)
config = load_and_migrate_config(temp_config_file)
migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(config_content))
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert config.legacy_conf_dir == expected_value
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
def test_load_and_migrate_backs_up_file(tmp_path: Path, patch_rootdir: None):
"""Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config)
@ -235,17 +233,15 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
def test_migrate_v4(tmp_path: Path, patch_rootdir: None):
def test_migration_2_migrates_settings():
"""Test migration from 4.0.0 to 4.0.1"""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v4_config)
conf = load_and_migrate_config(temp_config_file)
assert Version(conf.schema_version) >= Version("4.0.1")
assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
migrated_config_dict = migrate_v400_to_v401(yaml.safe_load(v4_config))
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert Version(config.schema_version) == Version("4.0.1")
assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
"""Test the failed migration of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config_with_bad_values)
@ -258,7 +254,7 @@ def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
assert temp_config_file.read_text() == v3_config_with_bad_values
def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(invalid_config)