mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(config): add "config" to class names to differentiate from SQLite migration classes
This commit is contained in:
parent
6946a3871f
commit
18b5aafade
@ -31,16 +31,16 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
|
|
||||||
AppConfigDict: TypeAlias = dict[str, Any]
|
AppConfigDict: TypeAlias = dict[str, Any]
|
||||||
|
|
||||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
ConfigMigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MigrationEntry:
|
class ConfigMigration:
|
||||||
"""Defines an individual migration."""
|
"""Defines an individual config migration."""
|
||||||
|
|
||||||
from_version: Version
|
from_version: Version
|
||||||
to_version: Version
|
to_version: Version
|
||||||
function: MigrationFunction
|
function: ConfigMigrationFunction
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
def __hash__(self) -> int:
|
||||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||||
|
@ -14,7 +14,7 @@ import yaml
|
|||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
import invokeai.configs as model_configs
|
import invokeai.configs as model_configs
|
||||||
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||||
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
@ -25,9 +25,9 @@ class ConfigMigrator:
|
|||||||
"""This class allows migrators to register their input and output versions."""
|
"""This class allows migrators to register their input and output versions."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._migrations: set[MigrationEntry] = set()
|
self._migrations: set[ConfigMigration] = set()
|
||||||
|
|
||||||
def register(self, migration: MigrationEntry) -> None:
|
def register(self, migration: ConfigMigration) -> None:
|
||||||
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
||||||
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
||||||
if migration_from_already_registered or migration_to_already_registered:
|
if migration_from_already_registered or migration_to_already_registered:
|
||||||
@ -37,7 +37,7 @@ class ConfigMigrator:
|
|||||||
self._migrations.add(migration)
|
self._migrations.add(migration)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
|
def _check_for_discontinuities(migrations: list[ConfigMigration]) -> None:
|
||||||
current_version = Version("3.0.0")
|
current_version = Version("3.0.0")
|
||||||
for m in migrations:
|
for m in migrations:
|
||||||
if current_version != m.from_version:
|
if current_version != m.from_version:
|
||||||
|
@ -14,7 +14,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
||||||
|
|
||||||
from .config_default import InvokeAIAppConfig
|
from .config_default import InvokeAIAppConfig
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
|
|||||||
return migrated_config
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
config_migration_1 = MigrationEntry(
|
config_migration_1 = ConfigMigration(
|
||||||
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
|
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -94,6 +94,6 @@ def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
|
|||||||
return migrated_config
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
config_migration_2 = MigrationEntry(
|
config_migration_2 = ConfigMigration(
|
||||||
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
|
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
|
||||||
)
|
)
|
||||||
|
@ -9,12 +9,10 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
from invokeai.app.services.config.config_default import (
|
from invokeai.app.services.config.config_default import (
|
||||||
CONFIG_SCHEMA_VERSION,
|
|
||||||
DefaultInvokeAIAppConfig,
|
DefaultInvokeAIAppConfig,
|
||||||
InvokeAIAppConfig,
|
InvokeAIAppConfig,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
|
from invokeai.app.services.config.config_migrate import get_config, load_and_migrate_config
|
||||||
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
|
|
||||||
from invokeai.app.services.shared.graph import Graph
|
from invokeai.app.services.shared.graph import Graph
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
@ -71,76 +69,6 @@ i like turtles
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GoodMigrations(MigrationsBase):
|
|
||||||
methods_run: int = 0
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, migrator: ConfigMigrator) -> None:
|
|
||||||
"""Define migrations to perform."""
|
|
||||||
|
|
||||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
|
||||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
cls.methods_run += 1
|
|
||||||
config_dict["migration_1"] = True
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
|
||||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
cls.methods_run += 1
|
|
||||||
config_dict["migration_2"] = True
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
|
||||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
cls.methods_run += 1
|
|
||||||
config_dict["migration_3"] = True
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
class BadMigrations1(MigrationsBase):
|
|
||||||
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, migrator: ConfigMigrator) -> None:
|
|
||||||
"""Define migrations to perform."""
|
|
||||||
|
|
||||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
|
||||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
|
||||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.2", to_version="10.0.3")
|
|
||||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
class BadMigrations2(MigrationsBase):
|
|
||||||
"""This one fails because the path to 10.0.2 is registered twice"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, migrator: ConfigMigrator) -> None:
|
|
||||||
"""Define migrations to perform."""
|
|
||||||
|
|
||||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
|
||||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
|
||||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
|
||||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
@migrator.register(from_version="10.0.0", to_version="10.0.2")
|
|
||||||
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||||
"""This may be overkill since the current tests don't need the root dir to exist"""
|
"""This may be overkill since the current tests don't need the root dir to exist"""
|
||||||
@ -359,40 +287,6 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
|||||||
InvokeAIArgs.did_parse = False
|
InvokeAIArgs.did_parse = False
|
||||||
|
|
||||||
|
|
||||||
def test_migration_check() -> None:
|
|
||||||
# Test the default set of migrations
|
|
||||||
migrator = ConfigMigrator(Migrations)
|
|
||||||
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
|
|
||||||
assert new_config is not None
|
|
||||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
|
|
||||||
|
|
||||||
# Test a custom set of migrations
|
|
||||||
GoodMigrations.methods_run = 0
|
|
||||||
migrator = ConfigMigrator(GoodMigrations)
|
|
||||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
|
||||||
assert new_config["schema_version"] == "10.0.2"
|
|
||||||
assert GoodMigrations.methods_run == 2
|
|
||||||
assert new_config.get("migration_2")
|
|
||||||
assert not new_config.get("migration_1")
|
|
||||||
|
|
||||||
GoodMigrations.methods_run = 0
|
|
||||||
migrator = ConfigMigrator(GoodMigrations)
|
|
||||||
new_config = migrator.run_migrations({"schema_version": "3.0.0"})
|
|
||||||
assert new_config["schema_version"] == "10.0.2"
|
|
||||||
assert GoodMigrations.methods_run == 3
|
|
||||||
assert all(new_config[x] for x in ["migration_1", "migration_2", "migration_3"])
|
|
||||||
|
|
||||||
# Test a migration that should fail validation
|
|
||||||
migrator = ConfigMigrator(BadMigrations1)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
|
||||||
|
|
||||||
# Test another bad migration
|
|
||||||
migrator = ConfigMigrator(BadMigrations2)
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def clear_config() -> Generator[None, None, None]:
|
def clear_config() -> Generator[None, None, None]:
|
||||||
try:
|
try:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user