mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
check for strictly contiguous from_version->to_version ranges
This commit is contained in:
parent
ab9ebef345
commit
6eaed9a9cb
@ -444,7 +444,7 @@ def get_config() -> InvokeAIAppConfig:
|
|||||||
####################################################
|
####################################################
|
||||||
|
|
||||||
|
|
||||||
@ConfigMigrator.register(from_version="0.0.0", to_version="4.0.0")
|
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
|
||||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Migrate a v3 config dictionary to a current config object.
|
"""Migrate a v3 config dictionary to a current config object.
|
||||||
|
|
||||||
|
@ -48,11 +48,14 @@ class ConfigMigrator:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
|
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
|
||||||
current_version = Version("0.0.0")
|
current_version = Version("3.0.0")
|
||||||
for m in migrations:
|
for m in migrations:
|
||||||
if current_version > m.from_version:
|
if current_version != m.from_version:
|
||||||
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
|
raise ValueError(
|
||||||
|
f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
|
||||||
|
)
|
||||||
|
current_version = m.to_version
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
||||||
@ -68,9 +71,9 @@ class ConfigMigrator:
|
|||||||
ValueError exception.
|
ValueError exception.
|
||||||
"""
|
"""
|
||||||
# Sort migrations by version number and raise a ValueError if
|
# Sort migrations by version number and raise a ValueError if
|
||||||
# any version range overlaps are detected. Discontinuities are ok
|
# any version range overlaps are detected.
|
||||||
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
||||||
cls._check_for_overlaps(sorted_migrations)
|
cls._check_for_discontinuities(sorted_migrations)
|
||||||
|
|
||||||
if "InvokeAI" in config_dict:
|
if "InvokeAI" in config_dict:
|
||||||
version = Version("3.0.0")
|
version = Version("3.0.0")
|
||||||
@ -78,7 +81,7 @@ class ConfigMigrator:
|
|||||||
version = Version(config_dict["schema_version"])
|
version = Version(config_dict["schema_version"])
|
||||||
|
|
||||||
for migration in sorted_migrations:
|
for migration in sorted_migrations:
|
||||||
if version >= migration.from_version and version < migration.to_version:
|
if version == migration.from_version and version < migration.to_version:
|
||||||
config_dict = migration.function(config_dict)
|
config_dict = migration.function(config_dict)
|
||||||
version = migration.to_version
|
version = migration.to_version
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from packaging.version import Version
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import (
|
from invokeai.app.services.config.config_default import (
|
||||||
@ -18,12 +19,13 @@ invalid_v4_0_1_config = """
|
|||||||
schema_version: 4.0.1
|
schema_version: 4.0.1
|
||||||
|
|
||||||
host: "192.168.1.1"
|
host: "192.168.1.1"
|
||||||
port: 8080
|
port: "ice cream"
|
||||||
"""
|
"""
|
||||||
|
|
||||||
v4_config = """
|
v4_config = """
|
||||||
schema_version: 4.0.0
|
schema_version: 4.0.0
|
||||||
|
|
||||||
|
precision: autocast
|
||||||
host: "192.168.1.1"
|
host: "192.168.1.1"
|
||||||
port: 8080
|
port: 8080
|
||||||
"""
|
"""
|
||||||
@ -141,6 +143,16 @@ def test_migrate_v3_backup(tmp_path: Path, patch_rootdir: None):
|
|||||||
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
|
assert temp_config_file.with_suffix(".yaml.bak").read_text() == v3_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_migrate_v4(tmp_path: Path, patch_rootdir: None):
|
||||||
|
"""Test migration from 4.0.0 to 4.0.1"""
|
||||||
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
|
temp_config_file.write_text(v4_config)
|
||||||
|
|
||||||
|
conf = load_and_migrate_config(temp_config_file)
|
||||||
|
assert Version(conf.schema_version) >= Version("4.0.1")
|
||||||
|
assert conf.precision == "auto" # we expect 'autocast' to be replaced with 'auto' during 4.0.1 migration
|
||||||
|
|
||||||
|
|
||||||
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
|
def test_failed_migrate_backup(tmp_path: Path, patch_rootdir: None):
|
||||||
"""Test the failed migration of the config file."""
|
"""Test the failed migration of the config file."""
|
||||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
@ -162,13 +174,15 @@ def test_bails_on_invalid_config(tmp_path: Path, patch_rootdir: None):
|
|||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
load_and_migrate_config(temp_config_file)
|
load_and_migrate_config(temp_config_file)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config])
|
@pytest.mark.parametrize("config_content", [invalid_v5_config, invalid_v4_0_1_config])
|
||||||
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str):
|
def test_bails_on_config_with_unsupported_version(tmp_path: Path, patch_rootdir: None, config_content: str):
|
||||||
"""Test reading configuration from a file."""
|
"""Test reading configuration from a file."""
|
||||||
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
temp_config_file = tmp_path / "temp_invokeai.yaml"
|
||||||
temp_config_file.write_text(config_content)
|
temp_config_file.write_text(config_content)
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="Invalid schema version"):
|
# with pytest.raises(RuntimeError, match="Invalid schema version"):
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
load_and_migrate_config(temp_config_file)
|
load_and_migrate_config(temp_config_file)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user