mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(config): add "config" to class names to differentiate from SQLite migration classes
This commit is contained in:
parent
6946a3871f
commit
18b5aafade
@ -31,16 +31,16 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
|
||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationEntry:
|
||||
"""Defines an individual migration."""
|
||||
class ConfigMigration:
|
||||
"""Defines an individual config migration."""
|
||||
|
||||
from_version: Version
|
||||
to_version: Version
|
||||
function: MigrationFunction
|
||||
function: ConfigMigrationFunction
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
|
@ -14,7 +14,7 @@ import yaml
|
||||
from packaging.version import Version
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
@ -25,9 +25,9 @@ class ConfigMigrator:
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[MigrationEntry] = set()
|
||||
self._migrations: set[ConfigMigration] = set()
|
||||
|
||||
def register(self, migration: MigrationEntry) -> None:
|
||||
def register(self, migration: ConfigMigration) -> None:
|
||||
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
||||
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
||||
if migration_from_already_registered or migration_to_already_registered:
|
||||
@ -37,7 +37,7 @@ class ConfigMigrator:
|
||||
self._migrations.add(migration)
|
||||
|
||||
@staticmethod
|
||||
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
|
||||
def _check_for_discontinuities(migrations: list[ConfigMigration]) -> None:
|
||||
current_version = Version("3.0.0")
|
||||
for m in migrations:
|
||||
if current_version != m.from_version:
|
||||
|
@ -14,7 +14,7 @@ from pathlib import Path
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||
|
||||
from .config_default import InvokeAIAppConfig
|
||||
|
||||
@ -68,7 +68,7 @@ def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
|
||||
return migrated_config
|
||||
|
||||
|
||||
config_migration_1 = MigrationEntry(
|
||||
config_migration_1 = ConfigMigration(
|
||||
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
|
||||
)
|
||||
|
||||
@ -94,6 +94,6 @@ def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
|
||||
return migrated_config
|
||||
|
||||
|
||||
config_migration_2 = MigrationEntry(
|
||||
config_migration_2 = ConfigMigration(
|
||||
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
|
||||
)
|
||||
|
@ -9,12 +9,10 @@ from pydantic import ValidationError
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.config.config_default import (
|
||||
CONFIG_SCHEMA_VERSION,
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
)
|
||||
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
|
||||
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
|
||||
from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
@ -71,76 +69,6 @@ i like turtles
|
||||
"""
|
||||
|
||||
|
||||
class GoodMigrations(MigrationsBase):
|
||||
methods_run: int = 0
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: ConfigMigrator) -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
cls.methods_run += 1
|
||||
config_dict["migration_1"] = True
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
cls.methods_run += 1
|
||||
config_dict["migration_2"] = True
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
cls.methods_run += 1
|
||||
config_dict["migration_3"] = True
|
||||
return config_dict
|
||||
|
||||
|
||||
class BadMigrations1(MigrationsBase):
|
||||
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: ConfigMigrator) -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.2", to_version="10.0.3")
|
||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
class BadMigrations2(MigrationsBase):
|
||||
"""This one fails because the path to 10.0.2 is registered twice"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: ConfigMigrator) -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.2")
|
||||
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||
"""This may be overkill since the current tests don't need the root dir to exist"""
|
||||
@ -359,40 +287,6 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
||||
InvokeAIArgs.did_parse = False
|
||||
|
||||
|
||||
def test_migration_check() -> None:
|
||||
# Test the default set of migrations
|
||||
migrator = ConfigMigrator(Migrations)
|
||||
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
|
||||
assert new_config is not None
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
|
||||
|
||||
# Test a custom set of migrations
|
||||
GoodMigrations.methods_run = 0
|
||||
migrator = ConfigMigrator(GoodMigrations)
|
||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||
assert new_config["schema_version"] == "10.0.2"
|
||||
assert GoodMigrations.methods_run == 2
|
||||
assert new_config.get("migration_2")
|
||||
assert not new_config.get("migration_1")
|
||||
|
||||
GoodMigrations.methods_run = 0
|
||||
migrator = ConfigMigrator(GoodMigrations)
|
||||
new_config = migrator.run_migrations({"schema_version": "3.0.0"})
|
||||
assert new_config["schema_version"] == "10.0.2"
|
||||
assert GoodMigrations.methods_run == 3
|
||||
assert all(new_config[x] for x in ["migration_1", "migration_2", "migration_3"])
|
||||
|
||||
# Test a migration that should fail validation
|
||||
migrator = ConfigMigrator(BadMigrations1)
|
||||
with pytest.raises(ValueError):
|
||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||
|
||||
# Test another bad migration
|
||||
migrator = ConfigMigrator(BadMigrations2)
|
||||
with pytest.raises(ValueError):
|
||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user