tidy(config): add "config" to class names to differentiate from SQLite migration classes

This commit is contained in:
psychedelicious 2024-05-14 16:21:50 +10:00
parent 6946a3871f
commit 18b5aafade
4 changed files with 12 additions and 118 deletions

View File

@ -31,16 +31,16 @@ class PagingArgumentParser(argparse.ArgumentParser):
AppConfigDict: TypeAlias = dict[str, Any]
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
class ConfigMigration:
"""Defines an individual config migration."""
from_version: Version
to_version: Version
function: MigrationFunction
function: ConfigMigrationFunction
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.

View File

@ -14,7 +14,7 @@ import yaml
from packaging.version import Version
import invokeai.configs as model_configs
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -25,9 +25,9 @@ class ConfigMigrator:
"""This class allows migrators to register their input and output versions."""
def __init__(self) -> None:
self._migrations: set[MigrationEntry] = set()
self._migrations: set[ConfigMigration] = set()
def register(self, migration: MigrationEntry) -> None:
def register(self, migration: ConfigMigration) -> None:
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
if migration_from_already_registered or migration_to_already_registered:
@ -37,7 +37,7 @@ class ConfigMigrator:
self._migrations.add(migration)
@staticmethod
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
def _check_for_discontinuities(migrations: list[ConfigMigration]) -> None:
current_version = Version("3.0.0")
for m in migrations:
if current_version != m.from_version:

View File

@ -14,7 +14,7 @@ from pathlib import Path
from packaging.version import Version
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from .config_default import InvokeAIAppConfig
@ -68,7 +68,7 @@ def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
return migrated_config
config_migration_1 = MigrationEntry(
config_migration_1 = ConfigMigration(
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
)
@ -94,6 +94,6 @@ def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
return migrated_config
config_migration_2 = MigrationEntry(
config_migration_2 = ConfigMigration(
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
)

View File

@ -9,12 +9,10 @@ from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_default import (
CONFIG_SCHEMA_VERSION,
DefaultInvokeAIAppConfig,
InvokeAIAppConfig,
)
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config
from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -71,76 +69,6 @@ i like turtles
"""
class GoodMigrations(MigrationsBase):
methods_run: int = 0
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
cls.methods_run += 1
config_dict["migration_1"] = True
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
cls.methods_run += 1
config_dict["migration_2"] = True
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
cls.methods_run += 1
config_dict["migration_3"] = True
return config_dict
class BadMigrations1(MigrationsBase):
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.2", to_version="10.0.3")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
class BadMigrations2(MigrationsBase):
"""This one fails because the path to 10.0.2 is registered twice"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.2")
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
"""This may be overkill since the current tests don't need the root dir to exist"""
@ -359,40 +287,6 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False
def test_migration_check() -> None:
# Test the default set of migrations
migrator = ConfigMigrator(Migrations)
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
assert new_config is not None
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
# Test a custom set of migrations
GoodMigrations.methods_run = 0
migrator = ConfigMigrator(GoodMigrations)
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
assert new_config["schema_version"] == "10.0.2"
assert GoodMigrations.methods_run == 2
assert new_config.get("migration_2")
assert not new_config.get("migration_1")
GoodMigrations.methods_run = 0
migrator = ConfigMigrator(GoodMigrations)
new_config = migrator.run_migrations({"schema_version": "3.0.0"})
assert new_config["schema_version"] == "10.0.2"
assert GoodMigrations.methods_run == 3
assert all(new_config[x] for x in ["migration_1", "migration_2", "migration_3"])
# Test a migration that should fail validation
migrator = ConfigMigrator(BadMigrations1)
with pytest.raises(ValueError):
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
# Test another bad migration
migrator = ConfigMigrator(BadMigrations2)
with pytest.raises(ValueError):
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
@contextmanager
def clear_config() -> Generator[None, None, None]:
try: