tests(config): test migrations directly, not via load_and_migrate_config

This commit is contained in:
psychedelicious 2024-05-14 17:16:50 +10:00
parent 00ccd73d53
commit 6e40142a59

View File

@ -4,6 +4,7 @@ from tempfile import TemporaryDirectory
from typing import Any, Generator from typing import Any, Generator
import pytest import pytest
import yaml
from packaging.version import Version from packaging.version import Version
from pydantic import ValidationError from pydantic import ValidationError
@ -14,6 +15,7 @@ from invokeai.app.services.config.config_default import (
InvokeAIAppConfig, InvokeAIAppConfig,
) )
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
from invokeai.app.services.config.migrations import migrate_v300_to_v400, migrate_v400_to_v401
from invokeai.app.services.shared.graph import Graph from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -183,12 +185,10 @@ def test_read_config_from_file(tmp_path: Path, patch_rootdir: None):
assert config.port == 8080 assert config.port == 8080
def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None): def test_migration_1_migrates_settings(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(v3_config))
temp_config_file.write_text(v3_config) config = InvokeAIAppConfig.model_validate(migrated_config_dict)
config = load_and_migrate_config(temp_config_file)
assert config.outputs_dir == Path("/some/outputs/dir") assert config.outputs_dir == Path("/some/outputs/dir")
assert config.host == "192.168.1.1" assert config.host == "192.168.1.1"
assert config.port == 8080 assert config.port == 8080
@ -212,20 +212,18 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
("full/custom/path", Path("full/custom/path"), True), ("full/custom/path", Path("full/custom/path"), True),
], ],
) )
def test_migrate_v3_legacy_conf_dir_defaults( def test_migration_1_handles_legacy_conf_dir_defaults(
tmp_path: Path, patch_rootdir: None, legacy_conf_dir: str, expected_value: Path, expected_is_set: bool legacy_conf_dir: str, expected_value: Path, expected_is_set: bool
): ):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}" config_content = f"InvokeAI:\n Paths:\n legacy_conf_dir: {legacy_conf_dir}"
temp_config_file = tmp_path / "temp_invokeai.yaml" migrated_config_dict = migrate_v300_to_v400(yaml.safe_load(config_content))
temp_config_file.write_text(config_content) config = InvokeAIAppConfig.model_validate(migrated_config_dict)
config = load_and_migrate_config(temp_config_file)
assert config.legacy_conf_dir == expected_value assert config.legacy_conf_dir == expected_value
assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set assert ("legacy_conf_dir" in config.model_fields_set) is expected_is_set
def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None): def test_load_and_migrate_backs_up_file(tmp_path: Path, patch_rootdir: None):
"""Test the backup of the config file.""" """Test the backup of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config) temp_config_file.write_text(v3_config)
@ -235,17 +233,15 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
def test_migrate_v4(tmp_path: Path, patch_rootdir: None): def test_migration_2_migrates_settings():
"""Test migration from 4.0.0 to 4.0.1""" """Test migration from 4.0.0 to 4.0.1"""
temp_config_file = tmp_path / "temp_invokeai.yaml" migrated_config_dict = migrate_v400_to_v401(yaml.safe_load(v4_config))
temp_config_file.write_text(v4_config) config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert Version(config.schema_version) == Version("4.0.1")
conf = load_and_migrate_config(temp_config_file) assert config.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
assert Version(conf.schema_version) >= Version("4.0.1")
assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None): def test_load_and_migrate_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
"""Test the failed migration of the config file.""" """Test the failed migration of the config file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(v3_config_with_bad_values) temp_config_file.write_text(v3_config_with_bad_values)
@ -258,7 +254,7 @@ def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
assert temp_config_file.read_text() == v3_config_with_bad_values assert temp_config_file.read_text() == v3_config_with_bad_values
def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None): def test_load_and_migrate_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
"""Test reading configuration from a file.""" """Test reading configuration from a file."""
temp_config_file = tmp_path / "temp_invokeai.yaml" temp_config_file = tmp_path / "temp_invokeai.yaml"
temp_config_file.write_text(invalid_config) temp_config_file.write_text(invalid_config)