tidy(config): add "config" to class names to differentiate from SQLite migration classes

This commit is contained in:
psychedelicious 2024-05-14 16:21:50 +10:00
parent 6946a3871f
commit 18b5aafade
4 changed files with 12 additions and 118 deletions

View File

@ -31,16 +31,16 @@ class PagingArgumentParser(argparse.ArgumentParser):
AppConfigDict: TypeAlias = dict[str, Any] AppConfigDict: TypeAlias = dict[str, Any]
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass @dataclass
class MigrationEntry: class ConfigMigration:
"""Defines an individual migration.""" """Defines an individual config migration."""
from_version: Version from_version: Version
to_version: Version to_version: Version
function: MigrationFunction function: ConfigMigrationFunction
def __hash__(self) -> int: def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set. # Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.

View File

@ -14,7 +14,7 @@ import yaml
from packaging.version import Version from packaging.version import Version
import invokeai.configs as model_configs import invokeai.configs as model_configs
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2 from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -25,9 +25,9 @@ class ConfigMigrator:
"""This class allows migrators to register their input and output versions.""" """This class allows migrators to register their input and output versions."""
def __init__(self) -> None: def __init__(self) -> None:
self._migrations: set[MigrationEntry] = set() self._migrations: set[ConfigMigration] = set()
def register(self, migration: MigrationEntry) -> None: def register(self, migration: ConfigMigration) -> None:
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations) migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations) migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
if migration_from_already_registered or migration_to_already_registered: if migration_from_already_registered or migration_to_already_registered:
@ -37,7 +37,7 @@ class ConfigMigrator:
self._migrations.add(migration) self._migrations.add(migration)
@staticmethod @staticmethod
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None: def _check_for_discontinuities(migrations: list[ConfigMigration]) -> None:
current_version = Version("3.0.0") current_version = Version("3.0.0")
for m in migrations: for m in migrations:
if current_version != m.from_version: if current_version != m.from_version:

View File

@ -14,7 +14,7 @@ from pathlib import Path
from packaging.version import Version from packaging.version import Version
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from .config_default import InvokeAIAppConfig from .config_default import InvokeAIAppConfig
@ -68,7 +68,7 @@ def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
return migrated_config return migrated_config
config_migration_1 = MigrationEntry( config_migration_1 = ConfigMigration(
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400 from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
) )
@ -94,6 +94,6 @@ def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
return migrated_config return migrated_config
config_migration_2 = MigrationEntry( config_migration_2 = ConfigMigration(
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401 from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
) )

View File

@ -9,12 +9,10 @@ from pydantic import ValidationError
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.config.config_default import ( from invokeai.app.services.config.config_default import (
CONFIG_SCHEMA_VERSION,
DefaultInvokeAIAppConfig, DefaultInvokeAIAppConfig,
InvokeAIAppConfig, InvokeAIAppConfig,
) )
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
from invokeai.app.services.shared.graph import Graph from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -71,76 +69,6 @@ i like turtles
""" """
class GoodMigrations(MigrationsBase):
methods_run: int = 0
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
cls.methods_run += 1
config_dict["migration_1"] = True
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
cls.methods_run += 1
config_dict["migration_2"] = True
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
cls.methods_run += 1
config_dict["migration_3"] = True
return config_dict
class BadMigrations1(MigrationsBase):
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.2", to_version="10.0.3")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
class BadMigrations2(MigrationsBase):
"""This one fails because the path to 10.0.2 is registered twice"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.2")
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@pytest.fixture @pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
"""This may be overkill since the current tests don't need the root dir to exist""" """This may be overkill since the current tests don't need the root dir to exist"""
@ -359,40 +287,6 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
InvokeAIArgs.did_parse = False InvokeAIArgs.did_parse = False
def test_migration_check() -> None:
# Test the default set of migrations
migrator = ConfigMigrator(Migrations)
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
assert new_config is not None
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
# Test a custom set of migrations
GoodMigrations.methods_run = 0
migrator = ConfigMigrator(GoodMigrations)
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
assert new_config["schema_version"] == "10.0.2"
assert GoodMigrations.methods_run == 2
assert new_config.get("migration_2")
assert not new_config.get("migration_1")
GoodMigrations.methods_run = 0
migrator = ConfigMigrator(GoodMigrations)
new_config = migrator.run_migrations({"schema_version": "3.0.0"})
assert new_config["schema_version"] == "10.0.2"
assert GoodMigrations.methods_run == 3
assert all(new_config[x] for x in ["migration_1", "migration_2", "migration_3"])
# Test a migration that should fail validation
migrator = ConfigMigrator(BadMigrations1)
with pytest.raises(ValueError):
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
# Test another bad migration
migrator = ConfigMigrator(BadMigrations2)
with pytest.raises(ValueError):
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
@contextmanager @contextmanager
def clear_config() -> Generator[None, None, None]: def clear_config() -> Generator[None, None, None]:
try: try: